From 2b9f7c57dcab641c98a38f7a52b5d403e436abf2 Mon Sep 17 00:00:00 2001
From: stephen
Date: Thu, 23 Apr 2026 10:45:11 -0700
Subject: [PATCH 1/5] enforce panic-free clippy lints

---
 Cargo.lock | 2 +
 Cargo.toml | 5 +
 .../examples/agent_with_bedrock.rs | 20 ++--
 .../examples/document_with_bedrock.rs | 2 +-
 .../examples/embedding_with_bedrock.rs | 2 +-
 .../examples/extractor_with_bedrock.rs | 7 +-
 .../rig-bedrock/examples/image_generator.rs | 12 +-
 .../examples/image_with_bedrock.rs | 2 +-
 .../rig-bedrock/examples/rag_with_bedrock.rs | 2 +-
 .../examples/streaming_with_bedrock.rs | 2 +-
 .../streaming_with_bedrock_and_tools.rs | 2 +-
 rig-integrations/rig-bedrock/src/client.rs | 13 ++-
 .../rig-bedrock/src/completion.rs | 2 +-
 rig-integrations/rig-bedrock/src/lib.rs | 11 ++
 rig-integrations/rig-bedrock/src/streaming.rs | 2 +-
 .../src/types/completion_request.rs | 80 +++++++------
 .../rig-bedrock/src/types/converse_output.rs | 36 +++---
 .../rig-bedrock/src/types/errors.rs | 40 +++++--
 .../rig-bedrock/src/types/text_to_image.rs | 7 +-
 .../examples/vector_search_fastembed.rs | 2 +-
 .../examples/vector_search_fastembed_local.rs | 18 +--
 rig-integrations/rig-fastembed/src/lib.rs | 105 +++++++++++++----
 rig-integrations/rig-gemini-grpc/build.rs | 13 ++-
 .../examples/gemini_grpc_agent.rs | 2 +-
 .../rig-gemini-grpc/src/client.rs | 18 +--
 rig-integrations/rig-gemini-grpc/src/lib.rs | 3 +-
 rig-integrations/rig-helixdb/Cargo.toml | 1 +
 .../examples/vector_search_helixdb.rs | 19 ++-
 rig-integrations/rig-helixdb/src/lib.rs | 22 ++--
 .../examples/vector_search_local_ann.rs | 2 +-
 .../examples/vector_search_local_ann_agent.rs | 2 +-
 .../examples/vector_search_local_enn.rs | 2 +-
 .../examples/vector_search_s3_ann.rs | 2 +-
 rig-integrations/rig-lancedb/src/lib.rs | 19 ++-
 .../rig-lancedb/src/utils/deserializer.rs | 17 ++-
 .../rig-lancedb/tests/integration_tests.rs | 8 ++
 .../examples/vector_search_milvus.rs | 22 ++--
 rig-integrations/rig-milvus/src/lib.rs | 2 +-
 .../examples/vector_search_mongodb.rs | 14 +--
 rig-integrations/rig-mongodb/src/lib.rs | 46 ++++++--
 .../rig-mongodb/tests/integration_tests.rs | 26 +++++
 .../vector_search_movies_add_embeddings.rs | 20 ++--
 .../examples/vector_search_movies_consume.rs | 7 +-
 .../examples/vector_search_simple.rs | 11 +-
 rig-integrations/rig-neo4j/src/lib.rs | 11 +-
 .../rig-neo4j/src/vector_index.rs | 2 +-
 .../rig-neo4j/tests/integration_tests.rs | 22 ++++
 .../examples/vector_search_postgres.rs | 13 +--
 rig-integrations/rig-postgres/src/lib.rs | 8 +-
 .../rig-postgres/tests/integration_tests.rs | 22 ++++
 .../examples/qdrant_vector_search.rs | 2 +-
 .../rig-qdrant/tests/integration_tests.rs | 22 ++++
 .../examples/s3vectors_vector_search.rs | 8 +-
 rig-integrations/rig-s3vectors/src/lib.rs | 99 +++++++++-------
 .../examples/scylladb_vector_search.rs | 27 ++---
 rig-integrations/rig-scylladb/src/lib.rs | 4 +-
 .../rig-scylladb/tests/integration_tests.rs | 8 ++
 .../examples/vector_search_sqlite.rs | 2 +-
 rig-integrations/rig-sqlite/src/lib.rs | 18 ++-
 .../rig-sqlite/tests/integration_test.rs | 8 ++
 .../examples/vector_search_surreal.rs | 16 ++-
 .../rig-surrealdb/examples/vector_store.rs | 17 ++-
 rig-integrations/rig-surrealdb/src/lib.rs | 2 +
 .../examples/vectorize_vector_search.rs | 2 +-
 .../rig-vectorize/src/client/filter.rs | 30 +++--
 .../rig-vectorize/tests/integration_tests.rs | 8 ++
 rig-integrations/rig-vertexai/Cargo.toml | 2 +-
 .../examples/completion_vertexai.rs | 2 +-
 .../rig-vertexai/examples/tool_vertexai.rs | 6 +-
 rig-integrations/rig-vertexai/src/client.rs | 70
++++++----- .../rig-vertexai/src/completion.rs | 1 + rig-integrations/rig-vertexai/src/lib.rs | 11 ++ rig/rig-core/examples/agent.rs | 2 +- rig/rig-core/examples/agent_autonomous.rs | 2 +- .../examples/agent_evaluator_optimizer.rs | 9 +- rig/rig-core/examples/agent_orchestrator.rs | 11 +- .../examples/agent_parallelization.rs | 34 +++--- .../examples/agent_prompt_chaining.rs | 2 +- rig/rig-core/examples/agent_routing.rs | 2 +- rig/rig-core/examples/agent_stream_chat.rs | 2 +- .../examples/agent_with_agent_tool.rs | 15 ++- rig/rig-core/examples/agent_with_context.rs | 2 +- .../examples/agent_with_default_max_turns.rs | 2 +- .../examples/agent_with_echochambers.rs | 5 +- rig/rig-core/examples/agent_with_loaders.rs | 2 +- rig/rig-core/examples/agent_with_tools.rs | 2 +- .../examples/agent_with_tools_otel.rs | 15 ++- rig/rig-core/examples/calculator_chatbot.rs | 55 +++++---- rig/rig-core/examples/chain.rs | 2 +- .../examples/complex_agentic_loop_claude.rs | 10 +- rig/rig-core/examples/custom_vector_store.rs | 5 +- rig/rig-core/examples/debate.rs | 12 +- rig/rig-core/examples/discord_bot.rs | 9 +- rig/rig-core/examples/enum_dispatch.rs | 77 ++++++------ rig/rig-core/examples/extractor.rs | 2 +- rig/rig-core/examples/gemini_deep_research.rs | 2 +- .../examples/gemini_extractor_with_rag.rs | 2 +- .../examples/gemini_video_understanding.rs | 2 +- rig/rig-core/examples/manual_tool_calls.rs | 2 +- rig/rig-core/examples/multi_agent.rs | 19 +-- rig/rig-core/examples/multi_extract.rs | 2 +- rig/rig-core/examples/multi_turn_agent.rs | 55 +++++---- .../examples/multi_turn_agent_extended.rs | 55 +++++---- .../openai_agent_completions_api_otel.rs | 4 +- .../openai_streaming_with_tools_otel.rs | 15 ++- rig/rig-core/examples/pdf_agent.rs | 3 +- rig/rig-core/examples/rag.rs | 2 +- rig/rig-core/examples/rag_dynamic_tools.rs | 28 +++-- .../examples/rag_dynamic_tools_multi_turn.rs | 28 +++-- rig/rig-core/examples/rag_ollama.rs | 2 +- rig/rig-core/examples/reasoning_loop.rs | 58 +++++---- rig/rig-core/examples/request_hook.rs | 2 +- rig/rig-core/examples/rmcp.rs | 4 +- rig/rig-core/examples/sentiment_classifier.rs | 2 +- rig/rig-core/examples/transcription.rs | 89 +++++++------- rig/rig-core/examples/vector_search.rs | 2 +- rig/rig-core/examples/vector_search_cohere.rs | 2 +- rig/rig-core/src/agent/prompt_request/mod.rs | 41 ++++--- .../src/agent/prompt_request/streaming.rs | 43 ++++--- rig/rig-core/src/agent/tool.rs | 4 +- rig/rig-core/src/client/mod.rs | 45 ++++++- rig/rig-core/src/completion/request.rs | 9 +- rig/rig-core/src/embeddings/builder.rs | 15 +-- rig/rig-core/src/embeddings/embedding.rs | 22 ++-- rig/rig-core/src/http_client/mod.rs | 75 +++++------- rig/rig-core/src/http_client/multipart.rs | 13 ++- rig/rig-core/src/http_client/retry.rs | 10 +- rig/rig-core/src/integrations/cli_chatbot.rs | 13 ++- rig/rig-core/src/integrations/discord_bot.rs | 26 +++-- rig/rig-core/src/lib.rs | 10 ++ rig/rig-core/src/model/listing.rs | 3 +- rig/rig-core/src/one_or_many.rs | 13 +++ .../src/providers/anthropic/client.rs | 11 +- .../src/providers/anthropic/completion.rs | 10 +- .../src/providers/anthropic/decoders/line.rs | 32 ++--- .../src/providers/anthropic/decoders/sse.rs | 13 +-- .../src/providers/anthropic/streaming.rs | 2 +- rig/rig-core/src/providers/azure.rs | 72 +++++++----- rig/rig-core/src/providers/chatgpt/mod.rs | 19 +-- rig/rig-core/src/providers/cohere/client.rs | 11 +- .../src/providers/cohere/completion.rs | 29 +++-- .../src/providers/cohere/streaming.rs | 6 +- rig/rig-core/src/providers/copilot/mod.rs 
| 36 ++++-- rig/rig-core/src/providers/deepseek.rs | 11 +- rig/rig-core/src/providers/galadriel.rs | 16 +-- rig/rig-core/src/providers/gemini/client.rs | 24 ++-- .../src/providers/gemini/completion.rs | 20 +--- .../src/providers/gemini/embedding.rs | 11 +- .../providers/gemini/interactions_api/mod.rs | 36 ++++-- .../gemini/interactions_api/streaming.rs | 4 +- rig/rig-core/src/providers/groq.rs | 28 +++-- .../src/providers/huggingface/client.rs | 12 +- .../src/providers/huggingface/completion.rs | 18 ++- rig/rig-core/src/providers/hyperbolic.rs | 22 ++-- .../openai_chat_completions_compatible.rs | 11 +- rig/rig-core/src/providers/llamafile.rs | 14 +-- rig/rig-core/src/providers/minimax.rs | 42 ++++--- rig/rig-core/src/providers/mira.rs | 21 ++-- rig/rig-core/src/providers/mistral/client.rs | 12 +- .../src/providers/mistral/completion.rs | 4 +- rig/rig-core/src/providers/moonshot.rs | 43 ++++--- rig/rig-core/src/providers/ollama.rs | 38 +++--- rig/rig-core/src/providers/openai/client.rs | 28 ++--- .../src/providers/openai/completion/mod.rs | 18 ++- .../providers/openai/completion/streaming.rs | 5 +- .../src/providers/openai/embedding.rs | 10 +- .../src/providers/openai/image_generation.rs | 9 +- .../src/providers/openai/responses_api/mod.rs | 9 +- .../openai/responses_api/streaming.rs | 27 +++-- .../src/providers/openai/transcription.rs | 16 ++- .../src/providers/openrouter/client.rs | 12 +- .../src/providers/openrouter/completion.rs | 11 +- .../src/providers/openrouter/embedding.rs | 10 +- rig/rig-core/src/providers/perplexity.rs | 12 +- rig/rig-core/src/providers/together/client.rs | 12 +- rig/rig-core/src/providers/voyageai.rs | 12 +- rig/rig-core/src/providers/xai/client.rs | 12 +- rig/rig-core/src/providers/zai.rs | 42 ++++--- rig/rig-core/src/streaming.rs | 18 +-- rig/rig-core/src/telemetry/mod.rs | 14 +-- rig/rig-core/src/tool/mod.rs | 2 +- .../src/vector_store/in_memory_store.rs | 6 +- rig/rig-core/src/vector_store/lsh.rs | 30 +++-- rig/rig-core/tests/anthropic.rs | 8 ++ rig/rig-core/tests/anthropic/agent.rs | 2 +- .../tests/anthropic/default_max_turns.rs | 1 + .../tests/anthropic/empty_end_turn.rs | 6 +- rig/rig-core/tests/anthropic/image.rs | 2 +- rig/rig-core/tests/anthropic/models.rs | 2 +- .../tests/anthropic/multi_turn_streaming.rs | 2 +- .../tests/anthropic/plaintext_document.rs | 4 +- .../tests/anthropic/reasoning_roundtrip.rs | 4 +- .../anthropic/reasoning_tool_roundtrip.rs | 4 +- rig/rig-core/tests/anthropic/streaming.rs | 2 +- .../tests/anthropic/streaming_tools.rs | 2 +- .../tests/anthropic/structured_output.rs | 2 +- rig/rig-core/tests/anthropic/think_tool.rs | 1 + rig/rig-core/tests/azure.rs | 8 ++ rig/rig-core/tests/azure/transcription.rs | 2 +- rig/rig-core/tests/chatgpt.rs | 8 ++ rig/rig-core/tests/cohere.rs | 8 ++ rig/rig-core/tests/cohere/agent.rs | 2 +- rig/rig-core/tests/cohere/streaming.rs | 2 +- rig/rig-core/tests/cohere/streaming_tools.rs | 2 +- rig/rig-core/tests/cohere/tools.rs | 2 +- rig/rig-core/tests/copilot.rs | 8 ++ rig/rig-core/tests/core.rs | 8 ++ rig/rig-core/tests/deepseek.rs | 8 ++ rig/rig-core/tests/deepseek/agent.rs | 2 +- rig/rig-core/tests/deepseek/extractor.rs | 2 +- .../tests/deepseek/extractor_usage.rs | 10 +- rig/rig-core/tests/deepseek/multi_extract.rs | 2 +- .../tests/deepseek/permission_control.rs | 2 + .../tests/deepseek/reasoning_roundtrip.rs | 4 +- .../deepseek/reasoning_tool_roundtrip.rs | 4 +- rig/rig-core/tests/deepseek/request_hook.rs | 1 + rig/rig-core/tests/deepseek/streaming.rs | 2 +- .../tests/deepseek/streaming_tools.rs | 12 
+- rig/rig-core/tests/deepseek/tools.rs | 2 +- rig/rig-core/tests/galadriel.rs | 8 ++ rig/rig-core/tests/galadriel/agent.rs | 2 +- .../tests/galadriel/streaming_tools.rs | 2 +- rig/rig-core/tests/gemini.rs | 8 ++ rig/rig-core/tests/gemini/agent.rs | 2 +- rig/rig-core/tests/gemini/embeddings.rs | 4 +- rig/rig-core/tests/gemini/extractor.rs | 4 +- rig/rig-core/tests/gemini/interactions_api.rs | 20 +++- rig/rig-core/tests/gemini/models.rs | 12 +- .../tests/gemini/multi_turn_streaming.rs | 2 +- .../tests/gemini/reasoning_roundtrip.rs | 4 +- .../tests/gemini/reasoning_tool_roundtrip.rs | 4 +- rig/rig-core/tests/gemini/streaming.rs | 3 +- .../streaming_multimodal_tool_results.rs | 2 +- rig/rig-core/tests/gemini/streaming_tools.rs | 9 +- .../tests/gemini/structured_output.rs | 2 +- rig/rig-core/tests/gemini/transcription.rs | 2 +- rig/rig-core/tests/groq.rs | 8 ++ rig/rig-core/tests/groq/agent.rs | 2 +- rig/rig-core/tests/groq/context.rs | 2 +- rig/rig-core/tests/groq/extractor.rs | 2 +- rig/rig-core/tests/groq/extractor_usage.rs | 10 +- rig/rig-core/tests/groq/loaders.rs | 2 +- rig/rig-core/tests/groq/multi_extract.rs | 2 +- rig/rig-core/tests/groq/permission_control.rs | 2 + rig/rig-core/tests/groq/request_hook.rs | 1 + rig/rig-core/tests/groq/streaming.rs | 2 +- .../tests/groq/streaming_reasoning.rs | 2 +- rig/rig-core/tests/groq/streaming_tools.rs | 10 +- rig/rig-core/tests/groq/tools.rs | 2 +- rig/rig-core/tests/groq/transcription.rs | 2 +- rig/rig-core/tests/groq/typed_prompt_tools.rs | 2 +- rig/rig-core/tests/huggingface.rs | 8 ++ rig/rig-core/tests/huggingface/agent.rs | 2 +- rig/rig-core/tests/huggingface/context.rs | 2 +- .../tests/huggingface/image_generation.rs | 2 +- rig/rig-core/tests/huggingface/loaders.rs | 2 +- rig/rig-core/tests/huggingface/streaming.rs | 2 +- rig/rig-core/tests/huggingface/tools.rs | 2 +- .../tests/huggingface/transcription.rs | 2 +- rig/rig-core/tests/hyperbolic.rs | 8 ++ rig/rig-core/tests/hyperbolic/agent.rs | 2 +- .../tests/hyperbolic/audio_generation.rs | 2 +- .../tests/hyperbolic/image_generation.rs | 2 +- rig/rig-core/tests/llamacpp.rs | 8 ++ rig/rig-core/tests/llamafile.rs | 8 ++ rig/rig-core/tests/llamafile/support.rs | 2 +- rig/rig-core/tests/minimax.rs | 8 ++ rig/rig-core/tests/minimax/anthropic.rs | 1 + rig/rig-core/tests/minimax/openai.rs | 1 + rig/rig-core/tests/mira.rs | 8 ++ rig/rig-core/tests/mira/agent.rs | 2 +- rig/rig-core/tests/mira/models.rs | 2 +- rig/rig-core/tests/mira/tools.rs | 2 +- rig/rig-core/tests/mistral.rs | 8 ++ rig/rig-core/tests/mistral/agent.rs | 2 +- rig/rig-core/tests/mistral/embeddings.rs | 2 +- rig/rig-core/tests/mistral/extractor.rs | 2 +- rig/rig-core/tests/mistral/extractor_usage.rs | 10 +- rig/rig-core/tests/mistral/models.rs | 2 +- rig/rig-core/tests/mistral/multi_extract.rs | 2 +- .../tests/mistral/permission_control.rs | 2 + rig/rig-core/tests/mistral/request_hook.rs | 1 + rig/rig-core/tests/mistral/streaming.rs | 4 +- rig/rig-core/tests/mistral/streaming_tools.rs | 8 +- rig/rig-core/tests/mistral/transcription.rs | 2 +- .../tests/mistral/typed_prompt_tools.rs | 2 +- rig/rig-core/tests/moonshot.rs | 8 ++ rig/rig-core/tests/moonshot/agent.rs | 2 +- rig/rig-core/tests/moonshot/anthropic.rs | 1 + rig/rig-core/tests/moonshot/context.rs | 2 +- .../tests/moonshot/reasoning_history.rs | 4 +- rig/rig-core/tests/moonshot/tools.rs | 1 + rig/rig-core/tests/ollama.rs | 8 ++ rig/rig-core/tests/ollama/multimodal.rs | 2 +- rig/rig-core/tests/ollama/pause_control.rs | 4 +- rig/rig-core/tests/ollama/streaming.rs | 1 + 
rig/rig-core/tests/ollama/streaming_tools.rs | 1 + .../tests/ollama/structured_output.rs | 2 +- rig/rig-core/tests/openai.rs | 8 ++ rig/rig-core/tests/openai/agent.rs | 2 +- rig/rig-core/tests/openai/audio_generation.rs | 2 +- rig/rig-core/tests/openai/completions_api.rs | 25 +++- rig/rig-core/tests/openai/extractor.rs | 2 +- rig/rig-core/tests/openai/extractor_usage.rs | 10 +- rig/rig-core/tests/openai/image_generation.rs | 2 +- rig/rig-core/tests/openai/models.rs | 2 +- rig/rig-core/tests/openai/multi_extract.rs | 2 +- .../tests/openai/permission_control.rs | 2 + .../tests/openai/reasoning_roundtrip.rs | 4 +- .../tests/openai/reasoning_tool_roundtrip.rs | 4 +- rig/rig-core/tests/openai/request_hook.rs | 1 + rig/rig-core/tests/openai/streaming.rs | 4 +- rig/rig-core/tests/openai/streaming_tools.rs | 8 +- .../tests/openai/structured_output.rs | 4 +- rig/rig-core/tests/openai/transcription.rs | 2 +- rig/rig-core/tests/openai/websocket.rs | 2 +- rig/rig-core/tests/openrouter.rs | 8 ++ rig/rig-core/tests/openrouter/agent.rs | 2 +- rig/rig-core/tests/openrouter/extractor.rs | 2 +- .../tests/openrouter/extractor_usage.rs | 10 +- rig/rig-core/tests/openrouter/models.rs | 2 +- .../tests/openrouter/multi_extract.rs | 2 +- rig/rig-core/tests/openrouter/multimodal.rs | 6 +- .../tests/openrouter/permission_control.rs | 2 + .../tests/openrouter/provider_selection.rs | 2 +- .../tests/openrouter/reasoning_roundtrip.rs | 4 +- .../openrouter/reasoning_tool_roundtrip.rs | 4 +- rig/rig-core/tests/openrouter/request_hook.rs | 1 + rig/rig-core/tests/openrouter/streaming.rs | 4 +- .../tests/openrouter/streaming_tools.rs | 8 +- .../tests/openrouter/typed_prompt_tools.rs | 2 +- rig/rig-core/tests/perplexity.rs | 8 ++ rig/rig-core/tests/perplexity/agent.rs | 2 +- rig/rig-core/tests/together.rs | 8 ++ rig/rig-core/tests/together/agent.rs | 2 +- rig/rig-core/tests/together/context.rs | 2 +- rig/rig-core/tests/together/embeddings.rs | 2 +- rig/rig-core/tests/together/streaming.rs | 2 +- .../tests/together/streaming_tools.rs | 2 +- rig/rig-core/tests/together/tools.rs | 2 +- rig/rig-core/tests/voyageai.rs | 8 ++ rig/rig-core/tests/voyageai/embeddings.rs | 2 +- rig/rig-core/tests/xai.rs | 8 ++ rig/rig-core/tests/xai/agent.rs | 2 +- rig/rig-core/tests/xai/audio_generation.rs | 2 +- rig/rig-core/tests/xai/context.rs | 2 +- rig/rig-core/tests/xai/extractor.rs | 2 +- rig/rig-core/tests/xai/extractor_usage.rs | 10 +- rig/rig-core/tests/xai/image_generation.rs | 2 +- rig/rig-core/tests/xai/loaders.rs | 2 +- rig/rig-core/tests/xai/multi_extract.rs | 2 +- rig/rig-core/tests/xai/permission_control.rs | 2 + rig/rig-core/tests/xai/reasoning_roundtrip.rs | 4 +- .../tests/xai/reasoning_tool_roundtrip.rs | 4 +- rig/rig-core/tests/xai/request_hook.rs | 1 + rig/rig-core/tests/xai/streaming.rs | 2 +- rig/rig-core/tests/xai/streaming_tools.rs | 6 +- rig/rig-core/tests/xai/tools.rs | 2 +- rig/rig-core/tests/xai/typed_prompt_tools.rs | 2 +- rig/rig-core/tests/zai.rs | 8 ++ rig/rig-derive/Cargo.toml | 1 + .../examples/rig_tool/async_tool.rs | 10 +- rig/rig-derive/examples/rig_tool/full.rs | 10 +- rig/rig-derive/examples/rig_tool/simple.rs | 8 +- .../examples/rig_tool/with_description.rs | 10 +- rig/rig-derive/src/client.rs | 5 +- rig/rig-derive/src/custom.rs | 9 +- rig/rig-derive/src/lib.rs | 110 +++++++++++++----- rig/rig-derive/tests/calculator.rs | 8 ++ rig/rig-derive/tests/custom_name.rs | 8 ++ rig/rig-derive/tests/visibility.rs | 8 ++ 371 files changed, 2426 insertions(+), 1493 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 
f01d1a80b..fd5b59b12 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -9446,6 +9446,7 @@ dependencies = [
 name = "rig-derive"
 version = "0.1.12"
 dependencies = [
+ "anyhow",
 "convert_case",
 "deluxe",
 "indoc",
@@ -9499,6 +9500,7 @@ dependencies = [
 name = "rig-helixdb"
 version = "0.2.4"
 dependencies = [
+ "anyhow",
 "helix-rs",
 "rig-core",
 "serde",
diff --git a/Cargo.toml b/Cargo.toml
index f98b3728f..e91293fb7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,8 +10,13 @@ exclude = [
 
 [workspace.lints.clippy]
 dbg_macro = "forbid"
+expect_used = "deny"
+indexing_slicing = "deny"
+panic = "deny"
 todo = "forbid"
 unimplemented = "forbid"
+unreachable = "deny"
+unwrap_used = "deny"
 
 [profile.release]
 lto = true
diff --git a/rig-integrations/rig-bedrock/examples/agent_with_bedrock.rs b/rig-integrations/rig-bedrock/examples/agent_with_bedrock.rs
index a3b56dc1d..2abb8d7dc 100644
--- a/rig-integrations/rig-bedrock/examples/agent_with_bedrock.rs
+++ b/rig-integrations/rig-bedrock/examples/agent_with_bedrock.rs
@@ -29,19 +29,18 @@ async fn main() -> Result<(), anyhow::Error> {
     Ok(())
 }
 
-fn client() -> Client {
-    Client::from_env()
+fn client() -> Result {
+    Ok(Client::from_env()?)
 }
 
-async fn partial_agent() -> AgentBuilder {
-    let client = client();
-    client.agent(AMAZON_NOVA_LITE)
+fn partial_agent() -> Result, anyhow::Error>
+{
+    Ok(client()?.agent(AMAZON_NOVA_LITE))
 }
 
 /// Create an AWS Bedrock agent with a system prompt
 async fn basic() -> Result<(), anyhow::Error> {
-    let agent = partial_agent()
-        .await
+    let agent = partial_agent()?
         .preamble("Answer with json format only")
         .build();
 
@@ -53,8 +52,7 @@ async fn basic() -> Result<(), anyhow::Error> {
 
 /// Create an AWS Bedrock with tools
 async fn tools() -> Result<(), anyhow::Error> {
-    let calculator_agent = partial_agent()
-        .await
+    let calculator_agent = partial_agent()?
         .preamble("You must only do math by using a tool.")
         .max_tokens(1024)
         .tool(common::Adder)
@@ -69,7 +67,7 @@ async fn tools() -> Result<(), anyhow::Error> {
 }
 
 async fn context() -> Result<(), anyhow::Error> {
-    let model = client().completion_model(AMAZON_NOVA_LITE);
+    let model = client()?.completion_model(AMAZON_NOVA_LITE);
 
     // Create an agent with multiple context documents
     let agent = AgentBuilder::new(model)
@@ -92,7 +90,7 @@ async fn context() -> Result<(), anyhow::Error> {
 /// This example loads in all the rust examples from the rig-core crate and uses them as\\
 /// context for the agent
 async fn loaders() -> Result<(), anyhow::Error> {
-    let model = client().completion_model(AMAZON_NOVA_LITE);
+    let model = client()?.completion_model(AMAZON_NOVA_LITE);
 
     // Load in all the rust examples
     let examples = FileLoader::with_glob("rig-core/examples/*.rs")?
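The workspace lint table in the Cargo.toml hunk above is what drives the rest of this series: assuming each member crate opts in with `[lints] workspace = true` in its own Cargo.toml, `cargo clippy` then rejects every `unwrap()`, `expect()`, `panic!`, `unreachable!`, and direct slice index. The sketch below is not part of the patch; the function and error type are hypothetical and only illustrate the rewrite pattern the following hunks apply, namely replacing panicking access with a fallible result that is bubbled up with `?`.

// Illustrative sketch only, not taken from this patch.
use std::fmt;

/// Hypothetical error type standing in for the provider errors used throughout the patch.
#[derive(Debug)]
pub struct EmptyResponse;

impl fmt::Display for EmptyResponse {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "response contained no candidates")
    }
}

impl std::error::Error for EmptyResponse {}

/// Before: `candidates[0].clone()` plus `.unwrap()` would now be denied by
/// `clippy::indexing_slicing` and `clippy::unwrap_used`.
/// After: the failure is returned to the caller instead of panicking.
pub fn first_candidate(candidates: &[String]) -> Result<String, EmptyResponse> {
    candidates.first().cloned().ok_or(EmptyResponse)
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let candidates = vec!["hello".to_string()];
    println!("{}", first_candidate(&candidates)?);
    Ok(())
}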
diff --git a/rig-integrations/rig-bedrock/examples/document_with_bedrock.rs b/rig-integrations/rig-bedrock/examples/document_with_bedrock.rs index be492002d..29ad67b29 100644 --- a/rig-integrations/rig-bedrock/examples/document_with_bedrock.rs +++ b/rig-integrations/rig-bedrock/examples/document_with_bedrock.rs @@ -20,7 +20,7 @@ async fn main() -> Result<(), anyhow::Error> { .with_target(false) .init(); - let client = Client::from_env(); + let client = Client::from_env()?; let agent = client .agent(AMAZON_NOVA_LITE) .preamble("Describe this document") diff --git a/rig-integrations/rig-bedrock/examples/embedding_with_bedrock.rs b/rig-integrations/rig-bedrock/examples/embedding_with_bedrock.rs index 606d2424c..190e1160a 100644 --- a/rig-integrations/rig-bedrock/examples/embedding_with_bedrock.rs +++ b/rig-integrations/rig-bedrock/examples/embedding_with_bedrock.rs @@ -17,7 +17,7 @@ async fn main() -> Result<(), anyhow::Error> { .with_target(false) .init(); - let client = Client::from_env(); + let client = Client::from_env()?; let embeddings = client .embeddings_with_ndims(AMAZON_TITAN_EMBED_TEXT_V2_0, 256) .document(Greetings { diff --git a/rig-integrations/rig-bedrock/examples/extractor_with_bedrock.rs b/rig-integrations/rig-bedrock/examples/extractor_with_bedrock.rs index db163a4aa..1f50cecd7 100644 --- a/rig-integrations/rig-bedrock/examples/extractor_with_bedrock.rs +++ b/rig-integrations/rig-bedrock/examples/extractor_with_bedrock.rs @@ -19,15 +19,12 @@ async fn main() -> Result<(), anyhow::Error> { .with_target(false) .init(); - let client = Client::from_env(); + let client = Client::from_env()?; let data_extractor = client.extractor::(AMAZON_NOVA_LITE).build(); let person = data_extractor .extract("Hello my name is John Doe! I am a software engineer.") .await?; - info!( - "AWS Bedrock: {}", - serde_json::to_string_pretty(&person).unwrap() - ); + info!("AWS Bedrock: {}", serde_json::to_string_pretty(&person)?); Ok(()) } diff --git a/rig-integrations/rig-bedrock/examples/image_generator.rs b/rig-integrations/rig-bedrock/examples/image_generator.rs index a933eacdc..1cc5eeada 100644 --- a/rig-integrations/rig-bedrock/examples/image_generator.rs +++ b/rig-integrations/rig-bedrock/examples/image_generator.rs @@ -10,8 +10,8 @@ use std::path::Path; const DEFAULT_PATH: &str = "./output.png"; #[tokio::main] -async fn main() { - let client = Client::from_env(); +async fn main() -> Result<(), anyhow::Error> { + let client = Client::from_env()?; let image_generation_model = client.image_generation_model(AMAZON_NOVA_CANVAS); let response = image_generation_model .image_generation_request() @@ -19,9 +19,11 @@ async fn main() { .width(512) .height(512) .send() - .await; + .await?; // save image - let mut file = File::create_new(Path::new(&DEFAULT_PATH)).expect("Failed to create file"); - let _ = file.write(&response.unwrap().image); + let mut file = File::create_new(Path::new(DEFAULT_PATH))?; + file.write_all(&response.image)?; + + Ok(()) } diff --git a/rig-integrations/rig-bedrock/examples/image_with_bedrock.rs b/rig-integrations/rig-bedrock/examples/image_with_bedrock.rs index 3fa6724ea..f2f228afb 100644 --- a/rig-integrations/rig-bedrock/examples/image_with_bedrock.rs +++ b/rig-integrations/rig-bedrock/examples/image_with_bedrock.rs @@ -19,7 +19,7 @@ async fn main() -> Result<(), anyhow::Error> { .with_target(false) .init(); - let client = rig_bedrock::client::Client::from_env(); + let client = rig_bedrock::client::Client::from_env()?; let agent = client .agent(AMAZON_NOVA_LITE) .preamble("You are 
an image describer.") diff --git a/rig-integrations/rig-bedrock/examples/rag_with_bedrock.rs b/rig-integrations/rig-bedrock/examples/rag_with_bedrock.rs index 026f1be7d..80a473a88 100644 --- a/rig-integrations/rig-bedrock/examples/rag_with_bedrock.rs +++ b/rig-integrations/rig-bedrock/examples/rag_with_bedrock.rs @@ -29,7 +29,7 @@ async fn main() -> Result<(), anyhow::Error> { .with_target(false) .init(); - let client = Client::from_env(); + let client = Client::from_env()?; let embedding_model = client.embedding_model_with_ndims(AMAZON_TITAN_EMBED_TEXT_V2_0, 256); // Generate embeddings for the definitions of all the documents using the specified embedding model. diff --git a/rig-integrations/rig-bedrock/examples/streaming_with_bedrock.rs b/rig-integrations/rig-bedrock/examples/streaming_with_bedrock.rs index a028beb7b..a4e82aa6b 100644 --- a/rig-integrations/rig-bedrock/examples/streaming_with_bedrock.rs +++ b/rig-integrations/rig-bedrock/examples/streaming_with_bedrock.rs @@ -6,7 +6,7 @@ use rig_bedrock::{client::Client, completion::AMAZON_NOVA_LITE}; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Create streaming agent with a single context prompt - let agent = Client::from_env() + let agent = Client::from_env()? .agent(AMAZON_NOVA_LITE) .preamble("Be precise and concise.") .temperature(0.5) diff --git a/rig-integrations/rig-bedrock/examples/streaming_with_bedrock_and_tools.rs b/rig-integrations/rig-bedrock/examples/streaming_with_bedrock_and_tools.rs index 32ce2632f..86b27d753 100644 --- a/rig-integrations/rig-bedrock/examples/streaming_with_bedrock_and_tools.rs +++ b/rig-integrations/rig-bedrock/examples/streaming_with_bedrock_and_tools.rs @@ -8,7 +8,7 @@ mod common; async fn main() -> Result<(), anyhow::Error> { tracing_subscriber::fmt().init(); // Create agent with a single context prompt and two tools - let agent = Client::from_env() + let agent = Client::from_env()? .agent(AMAZON_NOVA_LITE) .preamble( "You are a calculator here to help the user perform arithmetic diff --git a/rig-integrations/rig-bedrock/src/client.rs b/rig-integrations/rig-bedrock/src/client.rs index 8f0787b37..956d5d462 100644 --- a/rig-integrations/rig-bedrock/src/client.rs +++ b/rig-integrations/rig-bedrock/src/client.rs @@ -96,21 +96,22 @@ impl Client { impl ProviderClient for Client { type Input = Nothing; + type Error = rig::client::ProviderClientError; - fn from_env() -> Self + fn from_env() -> Result where Self: Sized, { - Client::new() + Ok(Client::new()) } - fn from_val(_: Nothing) -> Self + fn from_val(_: Nothing) -> Result where Self: Sized, { - panic!( - "Please use `Client::from_env` or `Client::with_profile_name(\"aws_profile\")` instead" - ); + Err(rig::client::ProviderClientError::InvalidConfiguration( + "use `Client::from_env()` or `Client::with_profile_name(\"aws_profile\")` instead", + )) } } diff --git a/rig-integrations/rig-bedrock/src/completion.rs b/rig-integrations/rig-bedrock/src/completion.rs index f6f385511..f2e0c9461 100644 --- a/rig-integrations/rig-bedrock/src/completion.rs +++ b/rig-integrations/rig-bedrock/src/completion.rs @@ -262,7 +262,7 @@ impl completion::CompletionModel for CompletionModel { .set_additional_model_request_fields(request.additional_params()) .set_inference_config(request.inference_config()) .set_tool_config(tool_config) - .set_system(request.system_prompt()) + .set_system(request.system_prompt()?) 
.set_messages(Some(messages)); async move { diff --git a/rig-integrations/rig-bedrock/src/lib.rs b/rig-integrations/rig-bedrock/src/lib.rs index 069f4c131..97143549a 100644 --- a/rig-integrations/rig-bedrock/src/lib.rs +++ b/rig-integrations/rig-bedrock/src/lib.rs @@ -1,3 +1,14 @@ +#![cfg_attr( + test, + allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable + ) +)] + pub mod client; pub mod completion; pub mod embedding; diff --git a/rig-integrations/rig-bedrock/src/streaming.rs b/rig-integrations/rig-bedrock/src/streaming.rs index b2282555e..9bd23d905 100644 --- a/rig-integrations/rig-bedrock/src/streaming.rs +++ b/rig-integrations/rig-bedrock/src/streaming.rs @@ -80,7 +80,7 @@ impl CompletionModel { .set_additional_model_request_fields(request.additional_params()) .set_inference_config(request.inference_config()) .set_tool_config(tool_config) - .set_system(request.system_prompt()) + .set_system(request.system_prompt()?) .set_messages(Some(prompt_with_history)); let response = converse_builder.send().await.map_err(|sdk_error| { diff --git a/rig-integrations/rig-bedrock/src/types/completion_request.rs b/rig-integrations/rig-bedrock/src/types/completion_request.rs index f60f9dae3..900193bb5 100644 --- a/rig-integrations/rig-bedrock/src/types/completion_request.rs +++ b/rig-integrations/rig-bedrock/src/types/completion_request.rs @@ -14,11 +14,11 @@ pub struct AwsCompletionRequest { pub prompt_caching: bool, } -fn cache_point_block() -> CachePointBlock { +fn cache_point_block() -> Result { CachePointBlock::builder() .r#type(CachePointType::Default) .build() - .expect("CachePointBlock type is set") + .map_err(|e| CompletionError::RequestError(e.into())) } impl AwsCompletionRequest { @@ -65,31 +65,33 @@ impl AwsCompletionRequest { if !tools.is_empty() { // Convert rig's ToolChoice to AWS Bedrock ToolChoice use aws_sdk_bedrockruntime::types as aws_bedrock; - let tool_choice = self.inner.tool_choice.as_ref().and_then(|choice| { - match choice { - rig::message::ToolChoice::Auto => Some(aws_bedrock::ToolChoice::Auto( + let tool_choice = self + .inner + .tool_choice + .as_ref() + .map(|choice| match choice { + rig::message::ToolChoice::Auto => Ok(Some(aws_bedrock::ToolChoice::Auto( aws_bedrock::AutoToolChoice::builder().build(), - )), - rig::message::ToolChoice::Required => Some(aws_bedrock::ToolChoice::Any( + ))), + rig::message::ToolChoice::Required => Ok(Some(aws_bedrock::ToolChoice::Any( aws_bedrock::AnyToolChoice::builder().build(), - )), - rig::message::ToolChoice::None => { - // Bedrock doesn't have a "None" option - just omit tool_choice - None - } - rig::message::ToolChoice::Specific { function_names } => { - // Use the first function name for Bedrock's specific tool choice - function_names.first().map(|name| { - aws_bedrock::ToolChoice::Tool( - aws_bedrock::SpecificToolChoice::builder() - .name(name.clone()) - .build() - .expect("Failed to build SpecificToolChoice"), - ) + ))), + rig::message::ToolChoice::None => Ok(None), + rig::message::ToolChoice::Specific { function_names } => function_names + .first() + .map(|name| { + aws_bedrock::SpecificToolChoice::builder() + .name(name.clone()) + .build() + .map(aws_bedrock::ToolChoice::Tool) + .map(Some) + .map_err(|e| CompletionError::RequestError(e.into())) }) - } - } - }); + .transpose() + .map(Option::flatten), + }) + .transpose()? 
+ .flatten(); let config = ToolConfiguration::builder() .set_tools(Some(tools)) @@ -103,7 +105,7 @@ impl AwsCompletionRequest { } } - pub fn system_prompt(&self) -> Option> { + pub fn system_prompt(&self) -> Result>, CompletionError> { let mut system_blocks = Vec::new(); if let Some(system_prompt) = self.inner.preamble.to_owned() @@ -121,12 +123,12 @@ impl AwsCompletionRequest { } if system_blocks.is_empty() { - None + Ok(None) } else { if self.prompt_caching { - system_blocks.push(SystemContentBlock::CachePoint(cache_point_block())); + system_blocks.push(SystemContentBlock::CachePoint(cache_point_block()?)); } - Some(system_blocks) + Ok(Some(system_blocks)) } } @@ -165,7 +167,7 @@ impl AwsCompletionRequest { && let Some(last_msg) = messages.last_mut() { let mut content = last_msg.content.clone(); - content.push(aws_bedrock::ContentBlock::CachePoint(cache_point_block())); + content.push(aws_bedrock::ContentBlock::CachePoint(cache_point_block()?)); *last_msg = aws_bedrock::Message::builder() .role(last_msg.role.clone()) .set_content(Some(content)) @@ -465,14 +467,17 @@ mod tests { }; let aws_request = aws_request(request, false); - let system_prompt = aws_request.system_prompt(); + let system_prompt = aws_request + .system_prompt() + .expect("system prompt should build") + .expect("system prompt should exist"); - assert!(system_prompt.is_some()); - let system_prompt = system_prompt.unwrap(); assert_eq!(system_prompt.len(), 1); assert_eq!( - system_prompt[0], - aws_bedrock::SystemContentBlock::Text("History system instruction".to_string()) + system_prompt.first(), + Some(&aws_bedrock::SystemContentBlock::Text( + "History system instruction".to_string() + )) ); } @@ -486,12 +491,15 @@ mod tests { let aws_request = aws_request(request, true); let system_prompt = aws_request .system_prompt() + .expect("system prompt should build") .expect("system prompt should exist"); assert_eq!(system_prompt.len(), 2); assert_eq!( - system_prompt[0], - aws_bedrock::SystemContentBlock::Text("System prompt".to_string()) + system_prompt.first(), + Some(&aws_bedrock::SystemContentBlock::Text( + "System prompt".to_string() + )) ); assert!(matches!( system_prompt.last(), diff --git a/rig-integrations/rig-bedrock/src/types/converse_output.rs b/rig-integrations/rig-bedrock/src/types/converse_output.rs index 325fdde32..a24c2181e 100644 --- a/rig-integrations/rig-bedrock/src/types/converse_output.rs +++ b/rig-integrations/rig-bedrock/src/types/converse_output.rs @@ -1051,7 +1051,8 @@ impl TryFrom for GuardrailCove value: aws_sdk_bedrockruntime::types::GuardrailCoverage, ) -> Result { Ok(GuardrailCoverage { - text_characters: value.text_characters().map(|x| x.try_into().unwrap()), - images: value.images().map(|x| x.try_into().unwrap()), + text_characters: value.text_characters().map(TryInto::try_into).transpose()?, + images: value.images().map(TryInto::try_into).transpose()?, }) } } @@ -1870,8 +1872,8 @@ impl TryFrom<&aws_sdk_bedrockruntime::types::GuardrailCoverage> for GuardrailCov value: &aws_sdk_bedrockruntime::types::GuardrailCoverage, ) -> Result { Ok(GuardrailCoverage { - text_characters: value.text_characters().map(|x| x.try_into().unwrap()), - images: value.images().map(|x| x.try_into().unwrap()), + text_characters: value.text_characters().map(TryInto::try_into).transpose()?, + images: value.images().map(TryInto::try_into).transpose()?, }) } } @@ -2022,7 +2024,7 @@ impl TryFrom for aws_sdk_bedrockruntime::types::Message { .set_role(role) .set_content(content) .build() - .expect("AWS SDK message conversion 
should never fail!"); + .map_err(|e| TypeConversionError::new(&e.to_string()))?; Ok(res) } @@ -2619,7 +2621,7 @@ impl TryFrom for aws_sdk_bedrockruntime::types::DocumentBlock { .set_context(context) .set_citations(citations) .build() - .expect("aws document block conversion should not fail"); + .map_err(|e| TypeConversionError::new(&e.to_string()))?; Ok(res) } @@ -2820,7 +2822,7 @@ impl TryFrom for aws_sdk_bedrockruntime::types::S3Location { .set_uri(Some(value.uri)) .set_bucket_owner(value.bucket_owner) .build() - .expect("converting S3 bucket location should never fail"); + .map_err(|e| TypeConversionError::new(&e.to_string()))?; Ok(res) } @@ -2890,7 +2892,7 @@ impl TryFrom for aws_sdk_bedrockruntime::types::CitationsConfig let res = aws_sdk_bedrockruntime::types::CitationsConfig::builder() .set_enabled(Some(value.enabled)) .build() - .expect("Citation config conversion should never fail!"); + .map_err(|e| TypeConversionError::new(&e.to_string()))?; Ok(res) } @@ -2965,7 +2967,7 @@ impl TryFrom .set_format(format) .set_source(source) .build() - .expect("GuardrailConverseImageBlock conversion should never fail!"); + .map_err(|e| TypeConversionError::new(&e.to_string()))?; Ok(res) } } @@ -3119,7 +3121,7 @@ impl TryFrom .set_text(text) .set_qualifiers(qualifiers) .build() - .expect("GuardrailConversionTextBlock conversion should never fail!"); + .map_err(|e| TypeConversionError::new(&e.to_string()))?; Ok(res) } @@ -3214,7 +3216,7 @@ impl TryFrom for aws_sdk_bedrockruntime::types::ImageBlock { .set_format(format) .set_source(source) .build() - .expect("ImageBlock conversion should never fail!"); + .map_err(|e| TypeConversionError::new(&e.to_string()))?; Ok(res) } } @@ -3376,7 +3378,7 @@ impl TryFrom for aws_sdk_bedrockruntime::types::ReasoningTex .set_text(text) .set_signature(signature) .build() - .expect("ReasoningTextBlock conversion should never fail!"); + .map_err(|e| TypeConversionError::new(&e.to_string()))?; Ok(res) } @@ -3416,7 +3418,7 @@ impl TryFrom for aws_sdk_bedrockruntime::types::ToolResultBlock .set_content(content) .set_status(status) .build() - .expect("ToolResultBlock conversion should never fail!"); + .map_err(|e| TypeConversionError::new(&e.to_string()))?; Ok(res) } } @@ -3495,7 +3497,7 @@ impl TryFrom for aws_sdk_bedrockruntime::types::VideoBlock { .set_format(format) .set_source(source) .build() - .expect("VideoBlock conversion should never fail!"); + .map_err(|e| TypeConversionError::new(&e.to_string()))?; Ok(res) } @@ -3635,7 +3637,7 @@ impl TryFrom for aws_sdk_bedrockruntime::types::ToolUseBlock { .set_name(name) .set_input(input) .build() - .expect("ToolUseBlock shouldn't panic!"); + .map_err(|e| TypeConversionError::new(&e.to_string()))?; Ok(res) } diff --git a/rig-integrations/rig-bedrock/src/types/errors.rs b/rig-integrations/rig-bedrock/src/types/errors.rs index 773c7b3f5..840aac228 100644 --- a/rig-integrations/rig-bedrock/src/types/errors.rs +++ b/rig-integrations/rig-bedrock/src/types/errors.rs @@ -66,16 +66,36 @@ pub struct AwsSdkConverseStreamError(pub SdkError for CompletionError { fn from(value: AwsSdkConverseStreamError) -> Self { let error: String = match value.0.into_service_error() { - ConverseStreamError::ModelTimeoutException(e) => e.message.unwrap(), - ConverseStreamError::AccessDeniedException(e) => e.message.unwrap(), - ConverseStreamError::ResourceNotFoundException(e) => e.message.unwrap(), - ConverseStreamError::ThrottlingException(e) => e.message.unwrap(), - ConverseStreamError::ServiceUnavailableException(e) => e.message.unwrap(), - 
ConverseStreamError::InternalServerException(e) => e.message.unwrap(), - ConverseStreamError::ModelStreamErrorException(e) => e.message.unwrap(), - ConverseStreamError::ValidationException(e) => e.message.unwrap(), - ConverseStreamError::ModelNotReadyException(e) => e.message.unwrap(), - ConverseStreamError::ModelErrorException(e) => e.message.unwrap(), + ConverseStreamError::ModelTimeoutException(e) => e + .message + .unwrap_or_else(|| "Bedrock model timed out".to_string()), + ConverseStreamError::AccessDeniedException(e) => e + .message + .unwrap_or_else(|| "Bedrock access denied".to_string()), + ConverseStreamError::ResourceNotFoundException(e) => e + .message + .unwrap_or_else(|| "Bedrock resource not found".to_string()), + ConverseStreamError::ThrottlingException(e) => e + .message + .unwrap_or_else(|| "Bedrock request throttled".to_string()), + ConverseStreamError::ServiceUnavailableException(e) => e + .message + .unwrap_or_else(|| "Bedrock service unavailable".to_string()), + ConverseStreamError::InternalServerException(e) => e + .message + .unwrap_or_else(|| "Bedrock internal server error".to_string()), + ConverseStreamError::ModelStreamErrorException(e) => e + .message + .unwrap_or_else(|| "Bedrock streaming model error".to_string()), + ConverseStreamError::ValidationException(e) => e + .message + .unwrap_or_else(|| "Bedrock validation error".to_string()), + ConverseStreamError::ModelNotReadyException(e) => e + .message + .unwrap_or_else(|| "Bedrock model not ready".to_string()), + ConverseStreamError::ModelErrorException(e) => e + .message + .unwrap_or_else(|| "Bedrock model error".to_string()), _ => "An unexpected error occurred. Verify Internet connection or AWS keys".into(), }; CompletionError::ProviderError(error) diff --git a/rig-integrations/rig-bedrock/src/types/text_to_image.rs b/rig-integrations/rig-bedrock/src/types/text_to_image.rs index 4a7f61652..d7eea77f6 100644 --- a/rig-integrations/rig-bedrock/src/types/text_to_image.rs +++ b/rig-integrations/rig-bedrock/src/types/text_to_image.rs @@ -112,9 +112,12 @@ impl TryFrom } if let Some(images) = value.to_owned().images { + let image = images.first().ok_or_else(|| { + ImageGenerationError::ResponseError("Bedrock image response was empty".into()) + })?; let data = BASE64_STANDARD - .decode(&images[0]) - .expect("Could not decode image."); + .decode(image) + .map_err(|err| ImageGenerationError::ResponseError(err.to_string()))?; return Ok(Self { image: data, diff --git a/rig-integrations/rig-fastembed/examples/vector_search_fastembed.rs b/rig-integrations/rig-fastembed/examples/vector_search_fastembed.rs index ca38dd19d..3d71f18c3 100644 --- a/rig-integrations/rig-fastembed/examples/vector_search_fastembed.rs +++ b/rig-integrations/rig-fastembed/examples/vector_search_fastembed.rs @@ -23,7 +23,7 @@ async fn main() -> Result<(), anyhow::Error> { // Create OpenAI client let fastembed_client = rig_fastembed::Client::new(); - let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q); + let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q)?; let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) .documents(vec![ diff --git a/rig-integrations/rig-fastembed/examples/vector_search_fastembed_local.rs b/rig-integrations/rig-fastembed/examples/vector_search_fastembed_local.rs index ccddd397a..2ec5e72c7 100644 --- a/rig-integrations/rig-fastembed/examples/vector_search_fastembed_local.rs +++ 
b/rig-integrations/rig-fastembed/examples/vector_search_fastembed_local.rs @@ -1,3 +1,4 @@ +use anyhow::Context; use fastembed::{ EmbeddingModel as FastembedModel, Pooling, TextEmbedding as FastembedTextEmbedding, TokenizerFiles, UserDefinedEmbeddingModel, read_file_to_bytes, @@ -26,26 +27,25 @@ struct WordDefinition { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Get model info - let test_model_info = - FastembedTextEmbedding::get_model_info(&FastembedModel::AllMiniLML6V2).unwrap(); + let test_model_info = FastembedTextEmbedding::get_model_info(&FastembedModel::AllMiniLML6V2)?; // Set up model directory let model_dir = Path::new("./models/Qdrant--all-MiniLM-L6-v2-onnx/snapshots"); println!("Loading model from: {model_dir:?}"); // Load model files - let onnx_file = - read_file_to_bytes(&model_dir.join("model.onnx")).expect("Could not read model.onnx file"); + let onnx_file = read_file_to_bytes(&model_dir.join("model.onnx")) + .context("Could not read model.onnx file")?; let tokenizer_files = TokenizerFiles { tokenizer_file: read_file_to_bytes(&model_dir.join("tokenizer.json")) - .expect("Could not read tokenizer.json"), + .context("Could not read tokenizer.json")?, config_file: read_file_to_bytes(&model_dir.join("config.json")) - .expect("Could not read config.json"), + .context("Could not read config.json")?, special_tokens_map_file: read_file_to_bytes(&model_dir.join("special_tokens_map.json")) - .expect("Could not read special_tokens_map.json"), + .context("Could not read special_tokens_map.json")?, tokenizer_config_file: read_file_to_bytes(&model_dir.join("tokenizer_config.json")) - .expect("Could not read tokenizer_config.json"), + .context("Could not read tokenizer_config.json")?, }; // Create embedding model @@ -53,7 +53,7 @@ async fn main() -> Result<(), anyhow::Error> { UserDefinedEmbeddingModel::new(onnx_file, tokenizer_files).with_pooling(Pooling::Mean); let embedding_model = - EmbeddingModel::new_from_user_defined(user_defined_model, 384, test_model_info); + EmbeddingModel::new_from_user_defined(user_defined_model, 384, test_model_info)?; // Create documents let documents = vec![ diff --git a/rig-integrations/rig-fastembed/src/lib.rs b/rig-integrations/rig-fastembed/src/lib.rs index 7bec99537..0a6584639 100644 --- a/rig-integrations/rig-fastembed/src/lib.rs +++ b/rig-integrations/rig-fastembed/src/lib.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::{error::Error as StdError, fmt}; pub use fastembed::EmbeddingModel as FastembedModel; use fastembed::{InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel}; @@ -15,6 +16,35 @@ use rig::{Embed, embeddings::EmbeddingsBuilder}; #[derive(Clone)] pub struct Client; +#[derive(Debug, Clone)] +pub enum FastembedError { + UnknownModel(FastembedModel), + Initialization(String), + UnsupportedMake, +} + +impl fmt::Display for FastembedError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FastembedError::UnknownModel(model) => { + write!( + f, + "Failed to resolve FastEmbed model metadata for {model:?}" + ) + } + FastembedError::Initialization(message) => { + write!(f, "Failed to initialize FastEmbed model: {message}") + } + FastembedError::UnsupportedMake => write!( + f, + "`EmbeddingModel::make` is not supported for rig-fastembed; construct models via `Client::embedding_model` or `EmbeddingModel::new_from_user_defined`" + ), + } + } +} + +impl StdError for FastembedError {} + impl Default for Client { fn default() -> Self { Self::new() @@ -41,8 +71,13 @@ impl Client { 
/// let embedding_model = fastembed_client.embedding_model(&FastembedModel::AllMiniLML6V2Q); /// ``` #[cfg(feature = "hf-hub")] - pub fn embedding_model(&self, model: &FastembedModel) -> EmbeddingModel { - let ndims = TextEmbedding::get_model_info(model).unwrap().dim; + pub fn embedding_model( + &self, + model: &FastembedModel, + ) -> Result { + let ndims = TextEmbedding::get_model_info(model) + .map(|info| info.dim) + .map_err(|_| FastembedError::UnknownModel(model.clone()))?; EmbeddingModel::new(model, ndims) } @@ -54,66 +89,76 @@ impl Client { /// use rig_fastembed::{Client, FastembedModel}; /// /// // Initialize the Fastembed client + /// # async fn run() -> Result<(), Box> { /// let fastembed_client = Client::new(); /// - /// let embeddings = fastembed_client.embeddings(FastembedModel::AllMiniLML6V2Q) - /// .simple_document("doc0", "Hello, world!") - /// .simple_document("doc1", "Goodbye, world!") + /// let embeddings = fastembed_client + /// .embeddings(&FastembedModel::AllMiniLML6V2Q)? + /// .documents(vec![ + /// "Hello, world!".to_string(), + /// "Goodbye, world!".to_string(), + /// ])? /// .build() - /// .await - /// .expect("Failed to embed documents"); + /// .await?; + /// # let _ = embeddings; + /// # Ok(()) + /// # } + /// # let _ = run(); /// ``` #[cfg(feature = "hf-hub")] pub fn embeddings( &self, model: &fastembed::EmbeddingModel, - ) -> EmbeddingsBuilder { - EmbeddingsBuilder::new(self.embedding_model(model)) + ) -> Result, FastembedError> { + Ok(EmbeddingsBuilder::new(self.embedding_model(model)?)) } } #[derive(Clone)] pub struct EmbeddingModel { - embedder: Arc, + embedder: Option>, + init_error: Option, pub model: FastembedModel, ndims: usize, } impl EmbeddingModel { #[cfg(feature = "hf-hub")] - pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Self { + pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Result { let embedder = Arc::new( TextEmbedding::try_new( InitOptions::new(model.to_owned()).with_show_download_progress(true), ) - .unwrap(), + .map_err(|err| FastembedError::Initialization(err.to_string()))?, ); - Self { - embedder, + Ok(Self { + embedder: Some(embedder), + init_error: None, model: model.to_owned(), ndims, - } + }) } pub fn new_from_user_defined( user_defined_model: UserDefinedEmbeddingModel, ndims: usize, model_info: &ModelInfo, - ) -> Self { + ) -> Result { let fastembed_embedding_model = TextEmbedding::try_new_from_user_defined( user_defined_model, InitOptionsUserDefined::default(), ) - .unwrap(); + .map_err(|err| FastembedError::Initialization(err.to_string()))?; let embedder = Arc::new(fastembed_embedding_model); - Self { - embedder, + Ok(Self { + embedder: Some(embedder), + init_error: None, model: model_info.model.to_owned(), ndims, - } + }) } } @@ -122,9 +167,13 @@ impl embeddings::EmbeddingModel for EmbeddingModel { type Client = Client; - /// **PANICS**: FastEmbed models cannot be created via this method, which will panic fn make(_: &Self::Client, _: impl Into, _: Option) -> Self { - panic!("Cannot create a fastembed model via `EmbeddingModel::make`") + Self { + embedder: None, + init_error: Some(FastembedError::UnsupportedMake), + model: FastembedModel::AllMiniLML6V2Q, + ndims: 0, + } } fn ndims(&self) -> usize { @@ -135,10 +184,18 @@ impl embeddings::EmbeddingModel for EmbeddingModel { &self, documents: impl IntoIterator, ) -> Result, EmbeddingError> { + let Some(embedder) = &self.embedder else { + let message = self + .init_error + .as_ref() + .map(ToString::to_string) + .unwrap_or_else(|| "FastEmbed model 
initialization failed".to_string()); + return Err(EmbeddingError::ProviderError(message)); + }; + let documents_as_strings: Vec = documents.into_iter().collect(); - let documents_as_vec = self - .embedder + let documents_as_vec = embedder .embed(documents_as_strings.clone(), None) .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?; diff --git a/rig-integrations/rig-gemini-grpc/build.rs b/rig-integrations/rig-gemini-grpc/build.rs index 664c3b3f4..dcde9b854 100644 --- a/rig-integrations/rig-gemini-grpc/build.rs +++ b/rig-integrations/rig-gemini-grpc/build.rs @@ -1,15 +1,16 @@ -fn main() { - compile_gemini_protos(); +fn main() -> Result<(), Box> { + compile_gemini_protos() } -fn compile_gemini_protos() { +fn compile_gemini_protos() -> Result<(), Box> { unsafe { - std::env::set_var("PROTOC", protoc_bin_vendored::protoc_bin_path().unwrap()); + std::env::set_var("PROTOC", protoc_bin_vendored::protoc_bin_path()?); } tonic_build::configure() .build_server(false) .build_client(true) .type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]") - .compile_protos(&["proto/gemini.proto"], &["proto"]) - .expect("Failed to compile Gemini proto files"); + .compile_protos(&["proto/gemini.proto"], &["proto"])?; + + Ok(()) } diff --git a/rig-integrations/rig-gemini-grpc/examples/gemini_grpc_agent.rs b/rig-integrations/rig-gemini-grpc/examples/gemini_grpc_agent.rs index 836d74b14..34be3bc82 100644 --- a/rig-integrations/rig-gemini-grpc/examples/gemini_grpc_agent.rs +++ b/rig-integrations/rig-gemini-grpc/examples/gemini_grpc_agent.rs @@ -11,7 +11,7 @@ async fn main() -> Result<(), anyhow::Error> { .init(); // Initialize the Google Gemini gRPC client - let client = Client::from_env(); + let client = Client::from_env().map_err(|err| anyhow::anyhow!("{err}"))?; // Create agent with a single context prompt let agent = client diff --git a/rig-integrations/rig-gemini-grpc/src/client.rs b/rig-integrations/rig-gemini-grpc/src/client.rs index 00b783e46..9412a027a 100644 --- a/rig-integrations/rig-gemini-grpc/src/client.rs +++ b/rig-integrations/rig-gemini-grpc/src/client.rs @@ -90,24 +90,18 @@ impl Client { impl ProviderClient for Client { type Input = String; + type Error = Box; /// Create a new Google Gemini gRPC client from the `GEMINI_API_KEY` environment variable. - /// Panics if the environment variable is not set. - fn from_env() -> Self { - let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set"); + fn from_env() -> Result { + let api_key = std::env::var("GEMINI_API_KEY")?; tokio::task::block_in_place(|| { - tokio::runtime::Handle::current() - .block_on(Self::new(api_key)) - .expect("Failed to create Gemini gRPC client") + tokio::runtime::Handle::current().block_on(Self::new(api_key)) }) } - fn from_val(input: Self::Input) -> Self { - tokio::task::block_in_place(|| { - tokio::runtime::Handle::current() - .block_on(Self::new(input)) - .expect("Failed to create Gemini gRPC client") - }) + fn from_val(input: Self::Input) -> Result { + tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(Self::new(input))) } } diff --git a/rig-integrations/rig-gemini-grpc/src/lib.rs b/rig-integrations/rig-gemini-grpc/src/lib.rs index b0a088ae3..4eeb0bae2 100644 --- a/rig-integrations/rig-gemini-grpc/src/lib.rs +++ b/rig-integrations/rig-gemini-grpc/src/lib.rs @@ -5,9 +5,10 @@ //! //! # Example //! ```no_run +//! use rig::client::CompletionClient; //! use rig_gemini_grpc::{Client, completion::GEMINI_2_0_FLASH}; //! -//! # async fn example() -> Result<(), Box> { +//! 
# async fn example() -> Result<(), Box> { //! let client = Client::new("YOUR_API_KEY").await?; //! //! let completion_model = client.completion_model(GEMINI_2_0_FLASH); diff --git a/rig-integrations/rig-helixdb/Cargo.toml b/rig-integrations/rig-helixdb/Cargo.toml index 9ce194ef7..256529480 100644 --- a/rig-integrations/rig-helixdb/Cargo.toml +++ b/rig-integrations/rig-helixdb/Cargo.toml @@ -17,6 +17,7 @@ serde_json.workspace = true rig-core = { path = "../../rig/rig-core", version = "0.35.0", default-features = false } [dev-dependencies] +anyhow = { workspace = true } tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } [[example]] diff --git a/rig-integrations/rig-helixdb/examples/vector_search_helixdb.rs b/rig-integrations/rig-helixdb/examples/vector_search_helixdb.rs index fa5ee769f..37ddadde1 100644 --- a/rig-integrations/rig-helixdb/examples/vector_search_helixdb.rs +++ b/rig-integrations/rig-helixdb/examples/vector_search_helixdb.rs @@ -27,9 +27,9 @@ impl std::fmt::Display for WordDefinition { } #[tokio::main] -async fn main() { +async fn main() -> Result<(), anyhow::Error> { let openai_model = - rig::providers::openai::Client::from_env().embedding_model(openai::TEXT_EMBEDDING_ADA_002); + rig::providers::openai::Client::from_env()?.embedding_model(openai::TEXT_EMBEDDING_ADA_002); let helixdb_client = HelixDB::new(None, Some(6969), None); // Uses default port 6969 let vector_store = HelixDBVectorStore::new(helixdb_client, openai_model.clone()); @@ -49,13 +49,11 @@ async fn main() { }]; let documents = EmbeddingsBuilder::new(openai_model) - .documents(words) - .unwrap() + .documents(words)? .build() - .await - .expect("Failed to create embeddings"); + .await?; - vector_store.insert_documents(documents).await.unwrap(); + vector_store.insert_documents(documents).await?; let query = "What is a flurbo?"; let vector_req = VectorSearchRequest::builder() @@ -63,10 +61,7 @@ async fn main() { .samples(5) .build(); - let docs = vector_store - .top_n::(vector_req) - .await - .unwrap(); + let docs = vector_store.top_n::(vector_req).await?; for doc in docs { println!( @@ -76,4 +71,6 @@ async fn main() { doc = doc.2 ) } + + Ok(()) } diff --git a/rig-integrations/rig-helixdb/src/lib.rs b/rig-integrations/rig-helixdb/src/lib.rs index d16000ba2..5b487ac5c 100644 --- a/rig-integrations/rig-helixdb/src/lib.rs +++ b/rig-integrations/rig-helixdb/src/lib.rs @@ -10,12 +10,20 @@ use serde::{Deserialize, Serialize}; /// If you are unsure what type to use for the client, `helix_rs::HelixDB` is the typical default. /// /// Usage: -/// ```rust -/// let openai_model = -/// rig::providers::openai::Client::from_env().embedding_model("text-embedding-ada-002"); +/// ```no_run +/// use helix_rs::{HelixDB, HelixDBClient}; +/// use rig::client::{EmbeddingsClient, ProviderClient}; +/// use rig_helixdb::HelixDBVectorStore; +/// +/// # fn example() -> anyhow::Result<()> { +/// let openai_model = rig::providers::openai::Client::from_env()? 
+/// .embedding_model("text-embedding-ada-002"); /// /// let helixdb_client = HelixDB::new(None, Some(6969), None); /// let vector_store = HelixDBVectorStore::new(helixdb_client, openai_model.clone()); +/// # let _ = vector_store; +/// # Ok(()) +/// # } /// ``` pub struct HelixDBVectorStore { client: C, @@ -88,8 +96,8 @@ where } for (document, embeddings) in documents { - let json_document = serde_json::to_value(&document).unwrap(); - let json_document_as_string = serde_json::to_string(&json_document).unwrap(); + let json_document = serde_json::to_value(&document)?; + let json_document_as_string = serde_json::to_string(&json_document)?; for embedding in embeddings { let embedded_text = embedding.document; @@ -137,7 +145,7 @@ where .client .query::("VectorSearch", &query_input) .await - .unwrap(); + .map_err(|x| VectorStoreError::DatastoreError(x.to_string().into()))?; let docs = result .vec_docs @@ -189,7 +197,7 @@ where .client .query::("VectorSearch", &query_input) .await - .unwrap(); + .map_err(|x| VectorStoreError::DatastoreError(x.to_string().into()))?; // HelixDB gives us the cosine distance, so we need to use `-(cosine_dist - 1)` to get the cosine similarity score. let docs = result diff --git a/rig-integrations/rig-lancedb/examples/vector_search_local_ann.rs b/rig-integrations/rig-lancedb/examples/vector_search_local_ann.rs index a1e296dcf..2c1e7ca7c 100644 --- a/rig-integrations/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-integrations/rig-lancedb/examples/vector_search_local_ann.rs @@ -19,7 +19,7 @@ mod fixture; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; // Select an embedding model. let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); diff --git a/rig-integrations/rig-lancedb/examples/vector_search_local_ann_agent.rs b/rig-integrations/rig-lancedb/examples/vector_search_local_ann_agent.rs index e244a9fbc..45547c561 100644 --- a/rig-integrations/rig-lancedb/examples/vector_search_local_ann_agent.rs +++ b/rig-integrations/rig-lancedb/examples/vector_search_local_ann_agent.rs @@ -19,7 +19,7 @@ mod fixture; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; // Select an embedding model. let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); diff --git a/rig-integrations/rig-lancedb/examples/vector_search_local_enn.rs b/rig-integrations/rig-lancedb/examples/vector_search_local_enn.rs index 90a631b14..178ba2bbf 100644 --- a/rig-integrations/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-integrations/rig-lancedb/examples/vector_search_local_enn.rs @@ -18,7 +18,7 @@ mod fixture; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). 
- let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); diff --git a/rig-integrations/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-integrations/rig-lancedb/examples/vector_search_s3_ann.rs index dfe70713e..92afba965 100644 --- a/rig-integrations/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-integrations/rig-lancedb/examples/vector_search_s3_ann.rs @@ -21,7 +21,7 @@ mod fixture; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); diff --git a/rig-integrations/rig-lancedb/src/lib.rs b/rig-integrations/rig-lancedb/src/lib.rs index b6a2787aa..61a428a9b 100644 --- a/rig-integrations/rig-lancedb/src/lib.rs +++ b/rig-integrations/rig-lancedb/src/lib.rs @@ -1,3 +1,14 @@ +#![cfg_attr( + test, + allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable + ) +)] + use std::ops::Range; use lancedb::{ @@ -27,7 +38,7 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { /// Type on which vector searches can be performed for a lanceDb table. /// # Example -/// ``` +/// ```ignore /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; /// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel}; /// @@ -363,7 +374,7 @@ where /// Implement the `top_n` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`. /// # Example - /// ``` + /// ```ignore /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; /// use rig::providers::openai::{EmbeddingModel, Client, TEXT_EMBEDDING_ADA_002}; /// @@ -377,7 +388,7 @@ where /// let result = vector_store_index /// .top_n::("My boss says I zindle too much, what does that mean?", 1) /// .await?; - /// ``` + /// ```ignore async fn top_n<T: for<'a> Deserialize<'a> + Send>( &self, req: VectorSearchRequest, @@ -425,7 +436,7 @@ where /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`.
/// # Example - /// ``` + /// ```ignore /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; /// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel}; /// diff --git a/rig-integrations/rig-lancedb/src/utils/deserializer.rs b/rig-integrations/rig-lancedb/src/utils/deserializer.rs index 960c0ab98..600fe3949 100644 --- a/rig-integrations/rig-lancedb/src/utils/deserializer.rs +++ b/rig-integrations/rig-lancedb/src/utils/deserializer.rs @@ -65,7 +65,10 @@ impl RecordBatchDeserializer for RecordBatch { .iter() .enumerate() .fold(serde_json::Map::new(), |mut acc, (col_i, col)| { - acc.insert(column_names[col_i].to_string(), col[row_i].clone()); + if let (Some(name), Some(value)) = (column_names.get(col_i), col.get(row_i)) + { + acc.insert((*name).to_string(), value.clone()); + } acc }) }) @@ -576,7 +579,9 @@ impl RebuildObject for Vec> { self.iter() .enumerate() .fold(serde_json::Map::new(), |mut acc, (col_i, col)| { - acc.insert(col_names[col_i].to_string(), col[row_i].clone()); + if let (Some(name), Some(value)) = (col_names.get(col_i), col.get(row_i)) { + acc.insert((*name).to_string(), value.clone()); + } acc }) }) @@ -585,8 +590,12 @@ impl RebuildObject for Vec> { } fn build_map(&self) -> Vec { - let keys = &self[0]; - let values = &self[1]; + let Some(keys) = self.first() else { + return Vec::new(); + }; + let Some(values) = self.get(1) else { + return Vec::new(); + }; keys.iter() .zip(values) diff --git a/rig-integrations/rig-lancedb/tests/integration_tests.rs b/rig-integrations/rig-lancedb/tests/integration_tests.rs index c4ba5cf1a..e0994a6df 100644 --- a/rig-integrations/rig-lancedb/tests/integration_tests.rs +++ b/rig-integrations/rig-lancedb/tests/integration_tests.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + use serde_json::json; use arrow_array::RecordBatchIterator; diff --git a/rig-integrations/rig-milvus/examples/vector_search_milvus.rs b/rig-integrations/rig-milvus/examples/vector_search_milvus.rs index bab3bf064..7895b8493 100644 --- a/rig-integrations/rig-milvus/examples/vector_search_milvus.rs +++ b/rig-integrations/rig-milvus/examples/vector_search_milvus.rs @@ -26,18 +26,14 @@ impl std::fmt::Display for WordDefinition { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Create OpenAI client - let openai_client = rig::providers::openai::Client::from_env(); + let openai_client = rig::providers::openai::Client::from_env()?; let model = openai_client.embedding_model(rig::providers::openai::TEXT_EMBEDDING_3_SMALL); - let base_url = std::env::var("MILVUS_BASE_URL").expect("the MILVUS_BASE_URL env var to exist"); - let collection_name = std::env::var("MILVUS_COLLECTION_NAME") - .expect("the MILVUS_COLLECTION_NAME env var to exist"); - let database_name = - std::env::var("MILVUS_DATABASE_NAME").expect("the MILVUS_DATABASE_NAME env var to exist"); - let milvus_user = - std::env::var("MILVUS_USERNAME").expect("the MILVUS_USERNAME env var to exist"); - let milvus_password = - std::env::var("MILVUS_PASSWORD").expect("the MILVUS_PASSWORD env var to exist"); + let base_url = std::env::var("MILVUS_BASE_URL")?; + let collection_name = std::env::var("MILVUS_COLLECTION_NAME")?; + let database_name = std::env::var("MILVUS_DATABASE_NAME")?; + let milvus_user = std::env::var("MILVUS_USERNAME")?; + let milvus_password = std::env::var("MILVUS_PASSWORD")?; let vector_store = rig_milvus::MilvusVectorStore::new(model.clone(), base_url, database_name, 
collection_name) @@ -59,11 +55,9 @@ async fn main() -> Result<(), anyhow::Error> { }]; let documents = EmbeddingsBuilder::new(model.clone()) - .documents(words) - .unwrap() + .documents(words)? .build() - .await - .expect("Failed to create embeddings"); + .await?; vector_store.insert_documents(documents).await?; diff --git a/rig-integrations/rig-milvus/src/lib.rs b/rig-integrations/rig-milvus/src/lib.rs index f529bf616..406479247 100644 --- a/rig-integrations/rig-milvus/src/lib.rs +++ b/rig-integrations/rig-milvus/src/lib.rs @@ -207,7 +207,7 @@ where let insert_request = self.create_insert_request(data); - let body = serde_json::to_string(&insert_request).unwrap(); + let body = serde_json::to_string(&insert_request)?; let res = client.body(body).send().await?; diff --git a/rig-integrations/rig-mongodb/examples/vector_search_mongodb.rs b/rig-integrations/rig-mongodb/examples/vector_search_mongodb.rs index 6283569f9..3f80bdde6 100644 --- a/rig-integrations/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-integrations/rig-mongodb/examples/vector_search_mongodb.rs @@ -49,17 +49,13 @@ where #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; // Initialize MongoDB client - let mongodb_connection_string = - env::var("MONGODB_CONNECTION_STRING").expect("MONGODB_CONNECTION_STRING not set"); - let options = ClientOptions::parse(mongodb_connection_string) - .await - .expect("MongoDB connection string should be valid"); - - let mongodb_client = - MongoClient::with_options(options).expect("MongoDB client options should be valid"); + let mongodb_connection_string = env::var("MONGODB_CONNECTION_STRING")?; + let options = ClientOptions::parse(mongodb_connection_string).await?; + + let mongodb_client = MongoClient::with_options(options)?; // Initialize MongoDB vector store let collection: Collection = mongodb_client diff --git a/rig-integrations/rig-mongodb/src/lib.rs b/rig-integrations/rig-mongodb/src/lib.rs index 0d7f9eada..f95f6edfb 100644 --- a/rig-integrations/rig-mongodb/src/lib.rs +++ b/rig-integrations/rig-mongodb/src/lib.rs @@ -64,11 +64,11 @@ fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { /// A vector index for a MongoDB collection. /// # Example -/// ```rust +/// ```no_run /// use rig_mongodb::{MongoDbVectorIndex, SearchParams}; /// use rig::{providers::openai, vector_store::{VectorStoreIndex, VectorSearchRequest}, client::{ProviderClient, EmbeddingsClient}}; /// -/// # tokio_test::block_on(async { +/// # async fn example() -> anyhow::Result<()> { /// #[derive(serde::Deserialize, serde::Serialize, Debug)] /// struct WordDefinition { /// #[serde(rename = "_id")] @@ -78,7 +78,7 @@ fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { /// } /// /// let mongodb_client = mongodb::Client::with_uri_str("mongodb://localhost:27017").await?; // <-- replace with your mongodb uri. -/// let openai_client = openai::Client::from_env(); +/// let openai_client = openai::Client::from_env()?; /// /// let collection = mongodb_client.database("db").collection::(""); // <-- replace with your mongodb collection. 
/// @@ -94,15 +94,15 @@ fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { /// let req = VectorSearchRequest::builder() /// .query("My boss says I zindle too much, what does that mean?") /// .samples(1) -/// .build() -/// .unwrap(); +/// .build(); /// /// // Query the index /// let definitions = index /// .top_n::(req) /// .await?; -/// # Ok::<_, anyhow::Error>(()) -/// # }).unwrap() +/// # Ok(()) +/// # } +/// # let _ = example(); /// ``` pub struct MongoDbVectorIndex where @@ -378,8 +378,20 @@ where let mut results = Vec::new(); while let Some(doc) = cursor.next().await { let doc = doc.map_err(mongodb_to_rig_error)?; - let score = doc.get("score").expect("score").as_f64().expect("f64"); - let id = doc.get("_id").expect("_id").to_string(); + let score = doc + .get("score") + .and_then(serde_json::Value::as_f64) + .ok_or_else(|| { + VectorStoreError::DatastoreError(Box::new(std::io::Error::other( + "MongoDB vector search result missing numeric score", + ))) + })?; + let id = doc.get("_id").ok_or_else(|| { + VectorStoreError::DatastoreError(Box::new(std::io::Error::other( + "MongoDB vector search result missing _id", + ))) + })?; + let id = id.to_string(); let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?; results.push((score, id, doc_t)); } @@ -423,8 +435,20 @@ where let mut results = Vec::new(); while let Some(doc) = cursor.next().await { let doc = doc.map_err(mongodb_to_rig_error)?; - let score = doc.get("score").expect("score").as_f64().expect("f64"); - let id = doc.get("_id").expect("_id").to_string(); + let score = doc + .get("score") + .and_then(serde_json::Value::as_f64) + .ok_or_else(|| { + VectorStoreError::DatastoreError(Box::new(std::io::Error::other( + "MongoDB vector search result missing numeric score", + ))) + })?; + let id = doc.get("_id").ok_or_else(|| { + VectorStoreError::DatastoreError(Box::new(std::io::Error::other( + "MongoDB vector search result missing _id", + ))) + })?; + let id = id.to_string(); results.push((score, id)); } diff --git a/rig-integrations/rig-mongodb/tests/integration_tests.rs b/rig-integrations/rig-mongodb/tests/integration_tests.rs index 5ce659e2e..5d839a236 100644 --- a/rig-integrations/rig-mongodb/tests/integration_tests.rs +++ b/rig-integrations/rig-mongodb/tests/integration_tests.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + use futures::StreamExt; use mongodb::{ Collection, SearchIndexModel, @@ -35,8 +43,22 @@ const DATABASE_NAME: &str = "rig"; const USERNAME: &str = "riguser"; const PASSWORD: &str = "rigpassword"; +fn skip_if_docker_unavailable(test_name: &str) -> bool { + let docker_socket = std::path::Path::new("/var/run/docker.sock"); + if std::env::var_os("DOCKER_HOST").is_some() || docker_socket.exists() { + return false; + } + + eprintln!("skipping {test_name}: Docker is unavailable"); + true +} + #[tokio::test] async fn vector_search_test() { + if skip_if_docker_unavailable("vector_search_test") { + return; + } + // Setup mock openai API let server = httpmock::MockServer::start(); @@ -179,6 +201,10 @@ async fn vector_search_test() { #[tokio::test] async fn insert_documents_test() { + if skip_if_docker_unavailable("insert_documents_test") { + return; + } + // Setup mock openai API let server = httpmock::MockServer::start(); diff --git a/rig-integrations/rig-neo4j/examples/vector_search_movies_add_embeddings.rs 
b/rig-integrations/rig-neo4j/examples/vector_search_movies_add_embeddings.rs index b0d9ea63e..423d9a9e4 100644 --- a/rig-integrations/rig-neo4j/examples/vector_search_movies_add_embeddings.rs +++ b/rig-integrations/rig-neo4j/examples/vector_search_movies_add_embeddings.rs @@ -36,12 +36,12 @@ const INDEX_NAME: &str = "moviePlots"; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client - let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); - let openai_client: Client = Client::new(&openai_api_key).unwrap(); + let openai_api_key = env::var("OPENAI_API_KEY")?; + let openai_client: Client = Client::new(&openai_api_key)?; - let neo4j_uri = env::var("NEO4J_URI").expect("NEO4J_URI not set"); - let neo4j_username = env::var("NEO4J_USERNAME").expect("NEO4J_USERNAME not set"); - let neo4j_password = env::var("NEO4J_PASSWORD").expect("NEO4J_PASSWORD not set"); + let neo4j_uri = env::var("NEO4J_URI")?; + let neo4j_username = env::var("NEO4J_USERNAME")?; + let neo4j_password = env::var("NEO4J_PASSWORD")?; let neo4j_client = Neo4jClient::connect(&neo4j_uri, &neo4j_username, &neo4j_password).await?; @@ -147,11 +147,15 @@ async fn main() -> Result<(), anyhow::Error> { } async fn import_batch(graph: &Graph, nodes: &[Movie], batch_n: i32) -> Result<(), anyhow::Error> { - let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + let openai_api_key = env::var("OPENAI_API_KEY")?; let to_encode_list: Vec<String> = nodes .iter() - .map(|node| node.to_encode.clone().unwrap()) - .collect(); + .map(|node| { + node.to_encode + .clone() + .ok_or_else(|| anyhow::anyhow!("movie payload missing text to encode")) + }) + .collect::<Result<_, _>>()?; graph.run( Query::new(format!( diff --git a/rig-integrations/rig-neo4j/examples/vector_search_movies_consume.rs b/rig-integrations/rig-neo4j/examples/vector_search_movies_consume.rs index 8b7e1764d..87196a5fb 100644 --- a/rig-integrations/rig-neo4j/examples/vector_search_movies_consume.rs +++ b/rig-integrations/rig-neo4j/examples/vector_search_movies_consume.rs @@ -37,8 +37,8 @@ async fn main() -> Result<(), anyhow::Error> { const INDEX_NAME: &str = "moviePlotsEmbedding"; // Initialize OpenAI client - let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); - let openai_client: Client = Client::new(&openai_api_key).unwrap(); + let openai_api_key = env::var("OPENAI_API_KEY")?; + let openai_client: Client = Client::new(&openai_api_key)?; let neo4j_uri = "neo4j+s://demo.neo4jlabs.com:7687"; let neo4j_username = "recommendations"; @@ -50,8 +50,7 @@ async fn main() -> Result<(), anyhow::Error> { .user(neo4j_username) .password(neo4j_password) .db("recommendations") - .build() - .unwrap(), + .build()?, ) .await?; diff --git a/rig-integrations/rig-neo4j/examples/vector_search_simple.rs b/rig-integrations/rig-neo4j/examples/vector_search_simple.rs index b84a4ee83..38622f7b4 100644 --- a/rig-integrations/rig-neo4j/examples/vector_search_simple.rs +++ b/rig-integrations/rig-neo4j/examples/vector_search_simple.rs @@ -28,12 +28,12 @@ pub struct Word { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; // Initialize Neo4j client - let neo4j_uri = env::var("NEO4J_URI").expect("NEO4J_URI not set"); - let neo4j_username = env::var("NEO4J_USERNAME").expect("NEO4J_USERNAME not set"); - let neo4j_password = env::var("NEO4J_PASSWORD").expect("NEO4J_PASSWORD not set"); + let neo4j_uri
= env::var("NEO4J_URI")?; + let neo4j_username = env::var("NEO4J_USERNAME")?; + let neo4j_password = env::var("NEO4J_PASSWORD")?; let neo4j_client = Neo4jClient::connect(&neo4j_uri, &neo4j_username, &neo4j_password).await?; @@ -77,8 +77,7 @@ async fn main() -> Result<(), anyhow::Error> { }) .buffer_unordered(3) .try_collect::>() - .await - .unwrap(); + .await?; // Create a vector index on our vector store println!("Creating vector index..."); diff --git a/rig-integrations/rig-neo4j/src/lib.rs b/rig-integrations/rig-neo4j/src/lib.rs index 478715f1a..1c1786321 100644 --- a/rig-integrations/rig-neo4j/src/lib.rs +++ b/rig-integrations/rig-neo4j/src/lib.rs @@ -27,7 +27,7 @@ //! The index name must be unique among both indexes and constraints. //! ❗A newly created index is not immediately available but is created in the background. //! -//! ```cypher +//! ```text //! CREATE VECTOR INDEX moviePlots //! FOR (m:Movie) //! ON m.embedding @@ -39,7 +39,7 @@ //! //! ## Simple example: //! More examples can be found in the [/examples](https://github.com/0xPlaygrounds/rig/tree/main/rig-neo4j/examples) folder. -//! ``` +//! ```ignore //! use rig_neo4j::{vector_index::*, Neo4jClient}; //! use neo4rs::ConfigBuilder; //! use rig::{providers::openai::*, vector_store::VectorStoreIndex}; @@ -368,8 +368,13 @@ impl Neo4jClient { model.ndims() ); } + let embedding_property = index.properties.first().ok_or_else(|| { + VectorStoreError::DatastoreError(Box::new(std::io::Error::other( + "Neo4j index is missing an embedding property", + ))) + })?; IndexConfig::new(index.name.clone()) - .embedding_property(index.properties.first().unwrap()) + .embedding_property(embedding_property) .similarity_function(VectorSimilarityFunction::from_str( &index.options.index_config.vector_similarity_function, )?) diff --git a/rig-integrations/rig-neo4j/src/vector_index.rs b/rig-integrations/rig-neo4j/src/vector_index.rs index b4685b322..e9a69e4d7 100644 --- a/rig-integrations/rig-neo4j/src/vector_index.rs +++ b/rig-integrations/rig-neo4j/src/vector_index.rs @@ -122,7 +122,7 @@ where /// See [Query vector index](https://neo4j.com/docs/cypher-manual/current/indexes/semantic-indexes/vector-indexes/#query-vector-index) for more information. 
/// /// Query template: - /// ``` + /// ```text /// CALL db.index.vector.queryNodes($index_name, $num_candidates, $queryVector) /// YIELD node, score /// WHERE {where_clause} diff --git a/rig-integrations/rig-neo4j/tests/integration_tests.rs b/rig-integrations/rig-neo4j/tests/integration_tests.rs index 32575ba3d..58d81bec8 100644 --- a/rig-integrations/rig-neo4j/tests/integration_tests.rs +++ b/rig-integrations/rig-neo4j/tests/integration_tests.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + use serde_json::json; use testcontainers::{ GenericImage, ImageExt, @@ -18,6 +26,16 @@ use rig_neo4j::{Neo4jClient, ToBoltType}; const BOLT_PORT: u16 = 7687; const HTTP_PORT: u16 = 7474; +fn skip_if_docker_unavailable(test_name: &str) -> bool { + let docker_socket = std::path::Path::new("/var/run/docker.sock"); + if std::env::var_os("DOCKER_HOST").is_some() || docker_socket.exists() { + return false; + } + + eprintln!("skipping {test_name}: Docker is unavailable"); + true +} + #[derive(Embed, Clone, serde::Deserialize, Debug)] struct Word { id: String, @@ -27,6 +45,10 @@ struct Word { #[tokio::test] async fn vector_search_test() { + if skip_if_docker_unavailable("vector_search_test") { + return; + } + // Setup a local Neo 4J container for testing. NOTE: docker service must be running. let container = GenericImage::new("neo4j", "latest") .with_wait_for(WaitFor::Duration { diff --git a/rig-integrations/rig-postgres/examples/vector_search_postgres.rs b/rig-integrations/rig-postgres/examples/vector_search_postgres.rs index 14f0d58f0..92f06aaf1 100644 --- a/rig-integrations/rig-postgres/examples/vector_search_postgres.rs +++ b/rig-integrations/rig-postgres/examples/vector_search_postgres.rs @@ -35,17 +35,16 @@ async fn main() -> Result<(), anyhow::Error> { dotenvy::dotenv().ok(); // Create OpenAI client - let openai_client = openai::Client::from_env(); + let openai_client = openai::Client::from_env()?; let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_3_SMALL); // setup Postgres - let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL not set"); + let database_url = std::env::var("DATABASE_URL")?; let pool = PgPoolOptions::new() .max_connections(50) .idle_timeout(std::time::Duration::from_secs(5)) .connect(&database_url) - .await - .expect("Failed to create postgres pool"); + .await?; // make sure database is setup sqlx::migrate!("./examples/migrations").run(&pool).await?; @@ -78,11 +77,9 @@ async fn main() -> Result<(), anyhow::Error> { }]; let documents = EmbeddingsBuilder::new(model.clone()) - .documents(words) - .unwrap() + .documents(words)? 
.build() - .await - .expect("Failed to create embeddings"); + .await?; // delete documents from table to have a clean start (optional, not recommended for production) sqlx::query("TRUNCATE documents").execute(&pool).await?; diff --git a/rig-integrations/rig-postgres/src/lib.rs b/rig-integrations/rig-postgres/src/lib.rs index a4a5a6cfb..767545802 100644 --- a/rig-integrations/rig-postgres/src/lib.rs +++ b/rig-integrations/rig-postgres/src/lib.rs @@ -184,15 +184,17 @@ fn bind_value( value: Value, ) -> QueryAs<'_, Postgres, S, PgArguments> { match value { - Value::Null => unreachable!(), + Value::Null => builder.bind(Option::::None), Value::Bool(b) => builder.bind(b), Value::Number(num) => { if let Some(n) = num.as_f64() { builder.bind(n) } else if let Some(n) = num.as_i64() { builder.bind(n) + } else if let Some(n) = num.as_u64() { + builder.bind(n as i64) } else { - unreachable!() + builder.bind(num.to_string()) } } Value::String(s) => builder.bind(s), @@ -341,7 +343,7 @@ where ) -> Result<(), VectorStoreError> { for (document, embeddings) in documents { let id = Uuid::new_v4(); - let json_document = serde_json::to_value(&document).unwrap(); + let json_document = serde_json::to_value(&document)?; for embedding in embeddings { let embedding_text = embedding.document; diff --git a/rig-integrations/rig-postgres/tests/integration_tests.rs b/rig-integrations/rig-postgres/tests/integration_tests.rs index 9d7d753a8..2f76fc42e 100644 --- a/rig-integrations/rig-postgres/tests/integration_tests.rs +++ b/rig-integrations/rig-postgres/tests/integration_tests.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + use rig::client::EmbeddingsClient; use rig::providers::openai; use rig::vector_store::request::VectorSearchRequest; @@ -18,6 +26,16 @@ use testcontainers::{ const POSTGRES_PORT: u16 = 5432; +fn skip_if_docker_unavailable(test_name: &str) -> bool { + let docker_socket = std::path::Path::new("/var/run/docker.sock"); + if std::env::var_os("DOCKER_HOST").is_some() || docker_socket.exists() { + return false; + } + + eprintln!("skipping {test_name}: Docker is unavailable"); + true +} + #[derive(Embed, Clone, Serialize, Deserialize, Debug, PartialEq)] struct Word { id: String, @@ -28,6 +46,10 @@ struct Word { #[tokio::test] async fn vector_search_test() { + if skip_if_docker_unavailable("vector_search_test") { + return; + } + let container = start_container().await; let host = container.get_host().await.unwrap().to_string(); diff --git a/rig-integrations/rig-qdrant/examples/qdrant_vector_search.rs b/rig-integrations/rig-qdrant/examples/qdrant_vector_search.rs index 304fde8b1..5bd4ec914 100644 --- a/rig-integrations/rig-qdrant/examples/qdrant_vector_search.rs +++ b/rig-integrations/rig-qdrant/examples/qdrant_vector_search.rs @@ -48,7 +48,7 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client. 
// Get your API key from https://platform.openai.com/api-keys - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); diff --git a/rig-integrations/rig-qdrant/tests/integration_tests.rs b/rig-integrations/rig-qdrant/tests/integration_tests.rs index fb1573ef1..ba79b544e 100644 --- a/rig-integrations/rig-qdrant/tests/integration_tests.rs +++ b/rig-integrations/rig-qdrant/tests/integration_tests.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + use serde_json::json; use testcontainers::{ GenericImage, @@ -22,6 +30,16 @@ const QDRANT_PORT: u16 = 6333; const QDRANT_PORT_SECONDARY: u16 = 6334; const COLLECTION_NAME: &str = "rig-collection"; +fn skip_if_docker_unavailable(test_name: &str) -> bool { + let docker_socket = std::path::Path::new("/var/run/docker.sock"); + if std::env::var_os("DOCKER_HOST").is_some() || docker_socket.exists() { + return false; + } + + eprintln!("skipping {test_name}: Docker is unavailable"); + true +} + #[derive(Embed, Clone, serde::Deserialize, serde::Serialize, Debug)] struct Word { id: String, @@ -31,6 +49,10 @@ struct Word { #[tokio::test] async fn vector_search_test() { + if skip_if_docker_unavailable("vector_search_test") { + return; + } + // Setup a local qdrant container for testing. NOTE: docker service must be running. let container = GenericImage::new("qdrant/qdrant", "latest") .with_wait_for(WaitFor::Duration { diff --git a/rig-integrations/rig-s3vectors/examples/s3vectors_vector_search.rs b/rig-integrations/rig-s3vectors/examples/s3vectors_vector_search.rs index 0087f072c..140326cc8 100644 --- a/rig-integrations/rig-s3vectors/examples/s3vectors_vector_search.rs +++ b/rig-integrations/rig-s3vectors/examples/s3vectors_vector_search.rs @@ -21,10 +21,8 @@ struct Word { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { - let access_key_id = env::var("AWS_ACCESS_KEY_ID") - .expect("AWS_ACCESS_KEY_ID does not exist as an environment variable"); - let secret_access_key = env::var("AWS_SECRET_ACCESS_KEY") - .expect("AWS_ACCESS_KEY_ID does not exist as an environment variable"); + let access_key_id = env::var("AWS_ACCESS_KEY_ID")?; + let secret_access_key = env::var("AWS_SECRET_ACCESS_KEY")?; let credentials = Credentials::new(access_key_id, secret_access_key, None, None, "test"); let region_provider = RegionProviderChain::default_provider().or_else("us-east-1"); @@ -43,7 +41,7 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client. 
// Get your API key from https://platform.openai.com/api-keys - let openai_client = OpenAIClient::from_env(); + let openai_client = OpenAIClient::from_env()?; let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); diff --git a/rig-integrations/rig-s3vectors/src/lib.rs b/rig-integrations/rig-s3vectors/src/lib.rs index 098da46fb..0b5572318 100644 --- a/rig-integrations/rig-s3vectors/src/lib.rs +++ b/rig-integrations/rig-s3vectors/src/lib.rs @@ -220,21 +220,17 @@ fn document_to_json_value(value: &Document) -> Value { match value { Document::Null => Value::Null, Document::Bool(b) => Value::Bool(*b), - Document::Number(n) => { - let res = match n { - aws_smithy_types::Number::Float(f) => { - serde_json::Number::from_f64(f.to_owned()).unwrap() - } - aws_smithy_types::Number::NegInt(i) => { - serde_json::Number::from_i128(*i as i128).unwrap() - } - aws_smithy_types::Number::PosInt(u) => { - serde_json::Number::from_u128(*u as u128).unwrap() - } - }; - - serde_json::Value::Number(res) - } + Document::Number(n) => match n { + aws_smithy_types::Number::Float(f) => serde_json::Number::from_f64(*f) + .map(Value::Number) + .unwrap_or_else(|| Value::String(f.to_string())), + aws_smithy_types::Number::NegInt(i) => { + serde_json::Value::Number(serde_json::Number::from(*i)) + } + aws_smithy_types::Number::PosInt(u) => { + serde_json::Value::Number(serde_json::Number::from(*u)) + } + }, Document::String(s) => Value::String(s.clone()), Document::Array(arr) => Value::Array(arr.iter().map(document_to_json_value).collect()), Document::Object(obj) => { @@ -285,29 +281,41 @@ where query_builder = query_builder.filter(filter.inner().clone()) } - let query = query_builder.send().await.unwrap(); + let query = query_builder + .send() + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; let res: Vec<(f64, String, T)> = query .vectors .into_iter() - .filter(|vector| { - req.threshold().is_none_or(|threshold| { - (vector - .distance() - .expect("vector distance should always exist") as f64) - >= threshold - }) - }) .map(|x| { - let distance = x.distance.expect("vector distance should always exist") as f64; - let val = - document_to_json_value(&x.metadata.expect("metadata should always exist")); + let distance = x.distance.ok_or_else(|| { + VectorStoreError::DatastoreError(Box::new(std::io::Error::other( + "S3Vectors response missing distance", + ))) + })? as f64; + + if req + .threshold() + .is_some_and(|threshold| distance < threshold) + { + return Ok(None); + } - let metadata: T = serde_json::from_value(val) - .expect("converting JSON from S3Vectors to valid T should always work"); + let metadata_document = x.metadata.ok_or_else(|| { + VectorStoreError::DatastoreError(Box::new(std::io::Error::other( + "S3Vectors response missing metadata", + ))) + })?; + let val = document_to_json_value(&metadata_document); + let metadata: T = serde_json::from_value(val)?; - (distance, x.key, metadata) + Ok(Some((distance, x.key, metadata))) }) + .collect::, VectorStoreError>>()? 
+ .into_iter() + .flatten() .collect(); Ok(res) @@ -343,24 +351,33 @@ where query_builder = query_builder.filter(filter.inner().clone()) } - let query = query_builder.send().await.unwrap(); + let query = query_builder + .send() + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; let res: Vec<(f64, String)> = query .vectors .into_iter() - .filter(|vector| { - req.threshold().is_none_or(|threshold| { - (vector - .distance() - .expect("vector distance should always exist") as f64) - >= threshold - }) - }) .map(|x| { - let distance = x.distance.expect("vector distance should always exist") as f64; + let distance = x.distance.ok_or_else(|| { + VectorStoreError::DatastoreError(Box::new(std::io::Error::other( + "S3Vectors response missing distance", + ))) + })? as f64; + + if req + .threshold() + .is_some_and(|threshold| distance < threshold) + { + return Ok(None); + } - (distance, x.key) + Ok(Some((distance, x.key))) }) + .collect::, VectorStoreError>>()? + .into_iter() + .flatten() .collect(); Ok(res) diff --git a/rig-integrations/rig-scylladb/examples/scylladb_vector_search.rs b/rig-integrations/rig-scylladb/examples/scylladb_vector_search.rs index 66775c1a0..86e99d382 100644 --- a/rig-integrations/rig-scylladb/examples/scylladb_vector_search.rs +++ b/rig-integrations/rig-scylladb/examples/scylladb_vector_search.rs @@ -22,12 +22,10 @@ async fn main() -> Result<(), anyhow::Error> { // Create ScyllaDB session // In production, you would use your ScyllaDB cluster endpoints - let session = create_session("127.0.0.1:9042") - .await - .expect("Failed to create ScyllaDB session"); + let session = create_session("127.0.0.1:9042").await?; // Create OpenAI client and embedding model - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); // Create ScyllaDB vector store @@ -38,8 +36,7 @@ async fn main() -> Result<(), anyhow::Error> { "words", // table 1536, // dimensions for text-embedding-ada-002 ) - .await - .expect("Failed to create ScyllaDB vector store"); + .await?; // Create sample word definitions let words = vec![ @@ -86,8 +83,7 @@ async fn main() -> Result<(), anyhow::Error> { vector_store .insert_documents(documents_with_embeddings) - .await - .expect("Failed to insert documents"); + .await?; tracing::info!("Documents inserted successfully!"); @@ -99,10 +95,7 @@ async fn main() -> Result<(), anyhow::Error> { .build(); tracing::info!("Searching for: '{}'", query); - let results = vector_store - .top_n::(req.clone()) - .await - .expect("Failed to search vectors"); + let results = vector_store.top_n::(req.clone()).await?; tracing::info!("Top 3 similar definitions:"); for (i, (score, id, word)) in results.iter().enumerate() { @@ -117,10 +110,7 @@ async fn main() -> Result<(), anyhow::Error> { // Test ID-only search tracing::info!("Searching for IDs only..."); - let id_results = vector_store - .top_n_ids(req) - .await - .expect("Failed to search vector IDs"); + let id_results = vector_store.top_n_ids(req).await?; tracing::info!("Top 2 similar document IDs:"); for (i, (score, id)) in id_results.iter().enumerate() { @@ -135,10 +125,7 @@ async fn main() -> Result<(), anyhow::Error> { .samples(2) .build(); - let db_results = vector_store - .top_n::(req) - .await - .expect("Failed to search vectors"); + let db_results = vector_store.top_n::(req).await?; tracing::info!("Top 2 similar definitions:"); for (i, (score, id, word)) in db_results.iter().enumerate() { diff --git 
a/rig-integrations/rig-scylladb/src/lib.rs b/rig-integrations/rig-scylladb/src/lib.rs index d225e3261..6213a3591 100644 --- a/rig-integrations/rig-scylladb/src/lib.rs +++ b/rig-integrations/rig-scylladb/src/lib.rs @@ -485,7 +485,7 @@ where } // Sort by similarity score (descending) and take top n - candidates.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap()); + candidates.sort_by(|a, b| b.0.total_cmp(&a.0)); candidates.truncate(req.samples() as usize); Ok(candidates) @@ -536,7 +536,7 @@ where } // Sort by similarity score (descending) and take top n - candidates.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap()); + candidates.sort_by(|a, b| b.0.total_cmp(&a.0)); candidates.truncate(req.samples() as usize); Ok(candidates) diff --git a/rig-integrations/rig-scylladb/tests/integration_tests.rs b/rig-integrations/rig-scylladb/tests/integration_tests.rs index 7ced709b6..0788b57fe 100644 --- a/rig-integrations/rig-scylladb/tests/integration_tests.rs +++ b/rig-integrations/rig-scylladb/tests/integration_tests.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + use rig::client::EmbeddingsClient; use rig::providers::openai; use rig::vector_store::request::VectorSearchRequest; diff --git a/rig-integrations/rig-sqlite/examples/vector_search_sqlite.rs b/rig-integrations/rig-sqlite/examples/vector_search_sqlite.rs index 251272fab..b512aa475 100644 --- a/rig-integrations/rig-sqlite/examples/vector_search_sqlite.rs +++ b/rig-integrations/rig-sqlite/examples/vector_search_sqlite.rs @@ -57,7 +57,7 @@ async fn main() -> Result<(), anyhow::Error> { .init(); // Initialize OpenAI client - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; // Initialize the `sqlite-vec`extension // See: https://alexgarcia.xyz/sqlite-vec/rust.html diff --git a/rig-integrations/rig-sqlite/src/lib.rs b/rig-integrations/rig-sqlite/src/lib.rs index ba27e0d9a..3a671badb 100644 --- a/rig-integrations/rig-sqlite/src/lib.rs +++ b/rig-integrations/rig-sqlite/src/lib.rs @@ -403,8 +403,10 @@ impl SqliteSearchFilter { Value::Real(float) } else if let Some(int) = n.as_i64() { Value::Integer(int) + } else if let Some(int) = n.as_u64() { + Value::Integer(int as i64) } else { - unreachable!() + Value::Text(n.to_string()) }), Array(arr) => { let blob = serde_json::to_vec(&arr) @@ -435,8 +437,9 @@ impl SqliteSearchFilter { /// It uses the `sqlite-vec` extension to enable vector similarity search capabilities. 
/// /// # Example -/// ```rust +/// ```no_run /// use rig::{ +/// client::EmbeddingsClient, /// embeddings::EmbeddingsBuilder, /// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, /// vector_store::{InsertDocuments, VectorStoreIndex}, @@ -447,6 +450,7 @@ impl SqliteSearchFilter { /// use serde::{Deserialize, Serialize}; /// use tokio_rusqlite::Connection; /// +/// # async fn example() -> anyhow::Result<()> { /// #[derive(Embed, Clone, Debug, Deserialize, Serialize)] /// struct Document { /// id: String, @@ -479,11 +483,11 @@ impl SqliteSearchFilter { /// } /// /// let conn = Connection::open("vector_store.db").await?; -/// let openai_client = Client::new("YOUR_API_KEY"); +/// let openai_client = Client::new("YOUR_API_KEY")?; /// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); /// /// // Initialize vector store -/// let vector_store = SqliteVectorStore::new(conn, &model).await?; +/// let vector_store: SqliteVectorStore<_, Document> = SqliteVectorStore::new(conn, &model).await?; /// /// // Create documents /// let documents = vec![ @@ -511,8 +515,12 @@ impl SqliteSearchFilter { /// let req = VectorSearchRequest::builder() /// .query("Example query") /// .samples(2) -/// .build()?; +/// .build(); /// let results = index.top_n::(req).await?; +/// # let _ = results; +/// # Ok(()) +/// # } +/// # let _ = example(); /// ``` pub struct SqliteVectorIndex where diff --git a/rig-integrations/rig-sqlite/tests/integration_test.rs b/rig-integrations/rig-sqlite/tests/integration_test.rs index fe4637963..17a4500f7 100644 --- a/rig-integrations/rig-sqlite/tests/integration_test.rs +++ b/rig-integrations/rig-sqlite/tests/integration_test.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + use rig::vector_store::request::{SearchFilter, VectorSearchRequest}; use serde_json::json; diff --git a/rig-integrations/rig-surrealdb/examples/vector_search_surreal.rs b/rig-integrations/rig-surrealdb/examples/vector_search_surreal.rs index 6cf867618..544381cf4 100644 --- a/rig-integrations/rig-surrealdb/examples/vector_search_surreal.rs +++ b/rig-integrations/rig-surrealdb/examples/vector_search_surreal.rs @@ -30,7 +30,7 @@ impl std::fmt::Display for WordDefinition { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Create OpenAI client - let openai_client = openai::Client::from_env(); + let openai_client = openai::Client::from_env()?; let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); let surreal = Surreal::new::(()).await?; @@ -53,11 +53,9 @@ async fn main() -> Result<(), anyhow::Error> { }]; let documents = EmbeddingsBuilder::new(model.clone()) - .documents(words) - .unwrap() + .documents(words)? .build() - .await - .expect("Failed to create embeddings"); + .await?; // init vector store let vector_store = SurrealVectorStore::with_defaults(model, surreal); @@ -81,7 +79,13 @@ async fn main() -> Result<(), anyhow::Error> { } // Use the midpoint as similarity threshold to guarantee exactly one result is returned. 
- let midpoint = (results[0].0 + results[1].0) / 2.0; + let Some(first_result) = results.first() else { + return Err(anyhow::anyhow!("expected at least one result")); + }; + let Some(second_result) = results.get(1) else { + return Err(anyhow::anyhow!("expected at least two results")); + }; + let midpoint = (first_result.0 + second_result.0) / 2.0; println!( "Attempting vector search with cosine similarity threshold of {midpoint} and query: {query}" diff --git a/rig-integrations/rig-surrealdb/examples/vector_store.rs b/rig-integrations/rig-surrealdb/examples/vector_store.rs index 2369fdda2..0b4eb1cc6 100644 --- a/rig-integrations/rig-surrealdb/examples/vector_store.rs +++ b/rig-integrations/rig-surrealdb/examples/vector_store.rs @@ -29,7 +29,7 @@ impl std::fmt::Display for TopicDefinition { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Create OpenAI client - let openai_client = openai::Client::from_env(); + let openai_client = openai::Client::from_env()?; let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); let surreal = Surreal::new::(()).await?; @@ -71,14 +71,20 @@ async fn main() -> Result<(), anyhow::Error> { let results = vector_store.top_n::(req).await?; assert_eq!(results.len(), 3); - assert_eq!(results[0].2.topic, "pasta carbonara"); + let Some(first_result) = results.first() else { + return Err(anyhow::anyhow!("expected at least one result")); + }; + assert_eq!(first_result.2.topic, "pasta carbonara"); println!("{} results for query: {}", results.len(), query); for (distance, _id, doc) in results.iter() { println!("Result distance {distance} for topic: {doc}"); } - let midpoint = (results[0].0 + results[1].0) / 2.0; + let Some(second_result) = results.get(1) else { + return Err(anyhow::anyhow!("expected at least two results")); + }; + let midpoint = (first_result.0 + second_result.0) / 2.0; println!( "Attempting vector search with cosine similarity threshold of {midpoint} and query: {query}" @@ -93,7 +99,10 @@ async fn main() -> Result<(), anyhow::Error> { println!("{} results for query: {}", results.len(), query); assert_eq!(results.len(), 1); - assert_eq!(results[0].2.topic, "pasta carbonara"); + let Some(filtered_result) = results.first() else { + return Err(anyhow::anyhow!("expected one filtered result")); + }; + assert_eq!(filtered_result.2.topic, "pasta carbonara"); for (distance, _id, doc) in results.iter() { println!("Result distance {distance} for topic: {doc}"); diff --git a/rig-integrations/rig-surrealdb/src/lib.rs b/rig-integrations/rig-surrealdb/src/lib.rs index 49b2de188..794c78dc6 100644 --- a/rig-integrations/rig-surrealdb/src/lib.rs +++ b/rig-integrations/rig-surrealdb/src/lib.rs @@ -459,6 +459,7 @@ mod tests { } } + #[allow(clippy::panic)] #[test] fn filter_from_json_preserves_nested_values() { let filter = match SurrealSearchFilter::try_from(Filter::Eq( @@ -481,6 +482,7 @@ mod tests { assert!(sql.contains("tags: ['surreal', 'json']")); } + #[allow(clippy::panic)] #[tokio::test] async fn surreal_vector_store_supports_dynamic_context_filters() { fn assert_dyn(_: T) {} diff --git a/rig-integrations/rig-vectorize/examples/vectorize_vector_search.rs b/rig-integrations/rig-vectorize/examples/vectorize_vector_search.rs index e163abeb5..672223302 100644 --- a/rig-integrations/rig-vectorize/examples/vectorize_vector_search.rs +++ b/rig-integrations/rig-vectorize/examples/vectorize_vector_search.rs @@ -30,7 +30,7 @@ struct Word { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { - let openai_client = Client::from_env(); 
+ let openai_client = Client::from_env()?; let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_3_SMALL); let vector_store = VectorizeVectorStore::new( diff --git a/rig-integrations/rig-vectorize/src/client/filter.rs b/rig-integrations/rig-vectorize/src/client/filter.rs index eccdac68d..001768c7f 100644 --- a/rig-integrations/rig-vectorize/src/client/filter.rs +++ b/rig-integrations/rig-vectorize/src/client/filter.rs @@ -191,16 +191,16 @@ mod tests { let combined = filter1.and(filter2); let result = combined.into_inner(); - let obj = result.as_object().unwrap(); + let Value::Object(obj) = result else { + assert!(false, "combined filter should serialize to an object"); + return; + }; // Both keys should be present (implicit AND) assert!(obj.contains_key("category")); assert!(obj.contains_key("score")); - assert_eq!( - obj.get("category").unwrap(), - &json!({ "$eq": "programming" }) - ); - assert_eq!(obj.get("score").unwrap(), &json!({ "$gt": 0.5 })); + assert_eq!(obj.get("category"), Some(&json!({ "$eq": "programming" }))); + assert_eq!(obj.get("score"), Some(&json!({ "$gt": 0.5 }))); } #[test] @@ -213,12 +213,21 @@ mod tests { let result = combined.validate(); assert!(result.is_err()); - let err = result.unwrap_err(); + let err = match result { + Err(err) => err, + Ok(()) => { + assert!(false, "OR filters should fail validation"); + return; + } + }; match err { VectorizeError::UnsupportedFilterOperation(msg) => { assert!(msg.contains("OR")); } - _ => panic!("Expected UnsupportedFilterOperation error"), + other => assert!( + false, + "expected UnsupportedFilterOperation error, got {other:?}" + ), } } @@ -242,7 +251,10 @@ mod tests { .and(VectorizeFilter::lt("price", json!(100))); let result = filter.into_inner(); - let obj = result.as_object().unwrap(); + let Value::Object(obj) = result else { + assert!(false, "combined filter should serialize to an object"); + return; + }; assert_eq!(obj.len(), 3); assert!(obj.contains_key("category")); diff --git a/rig-integrations/rig-vectorize/tests/integration_tests.rs b/rig-integrations/rig-vectorize/tests/integration_tests.rs index 4154fa588..d9579e6cc 100644 --- a/rig-integrations/rig-vectorize/tests/integration_tests.rs +++ b/rig-integrations/rig-vectorize/tests/integration_tests.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Integration tests for rig-vectorize. //! //! These tests require a real Cloudflare Vectorize index and valid credentials. 
diff --git a/rig-integrations/rig-vertexai/Cargo.toml b/rig-integrations/rig-vertexai/Cargo.toml index 6ccad4a39..3daa24020 100644 --- a/rig-integrations/rig-vertexai/Cargo.toml +++ b/rig-integrations/rig-vertexai/Cargo.toml @@ -15,11 +15,11 @@ google-cloud-auth = { workspace = true } rig-core = { path = "../../rig/rig-core", version = "0.35.0", default-features = false } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } +thiserror = { workspace = true } tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } [dev-dependencies] anyhow = { workspace = true } schemars = { workspace = true } -thiserror = { workspace = true } tracing-subscriber = { workspace = true } diff --git a/rig-integrations/rig-vertexai/examples/completion_vertexai.rs b/rig-integrations/rig-vertexai/examples/completion_vertexai.rs index cb95766a9..38dbe3a71 100644 --- a/rig-integrations/rig-vertexai/examples/completion_vertexai.rs +++ b/rig-integrations/rig-vertexai/examples/completion_vertexai.rs @@ -8,7 +8,7 @@ async fn main() -> Result<(), anyhow::Error> { tracing_subscriber::fmt().with_target(false).init(); // Uses ADC credentials and expects GOOGLE_CLOUD_PROJECT to be set. See Client::builder() for more granular control. - let client = Client::from_env(); + let client = Client::from_env()?; let model = client.completion_model(GEMINI_2_5_FLASH_LITE); let request = model diff --git a/rig-integrations/rig-vertexai/examples/tool_vertexai.rs b/rig-integrations/rig-vertexai/examples/tool_vertexai.rs index 842061b53..a8dcaec42 100644 --- a/rig-integrations/rig-vertexai/examples/tool_vertexai.rs +++ b/rig-integrations/rig-vertexai/examples/tool_vertexai.rs @@ -7,6 +7,7 @@ use rig::{ use rig_vertexai::{Client, completion::GEMINI_2_5_FLASH_LITE}; use schemars::{JsonSchema, schema_for}; use serde::{Deserialize, Serialize}; +use serde_json::json; #[derive(Deserialize, JsonSchema)] struct OperationArgs { @@ -31,8 +32,7 @@ impl Tool for Adder { ToolDefinition { name: "add".to_string(), description: "Add x and y together".to_string(), - parameters: serde_json::to_value(schema_for!(OperationArgs)) - .expect("converting JSON schema to JSON value should never fail"), + parameters: json!(schema_for!(OperationArgs)), } } @@ -48,7 +48,7 @@ async fn main() -> Result<(), anyhow::Error> { tracing_subscriber::fmt().with_target(false).init(); // Create Vertex AI client using implicit credentials - let client = Client::from_env(); + let client = Client::from_env()?; // Create agent with a calculator tool let calculator_agent = client diff --git a/rig-integrations/rig-vertexai/src/client.rs b/rig-integrations/rig-vertexai/src/client.rs index 4a2d6503e..f0f4c35d8 100644 --- a/rig-integrations/rig-vertexai/src/client.rs +++ b/rig-integrations/rig-vertexai/src/client.rs @@ -5,6 +5,7 @@ use google_cloud_auth::credentials::Credentials; use rig::client::{CompletionClient, Nothing}; use rig::prelude::*; use std::sync::Arc; +use thiserror::Error; use tokio::sync::OnceCell; // Env vars and terminology (location, project) chosen to match google genai client @@ -18,22 +19,42 @@ use tokio::sync::OnceCell; /// Regional endpoints may be preferred for data residency requirements or to use regional quotas. pub const DEFAULT_LOCATION: &str = "global"; +#[derive(Clone, Debug, Error)] +pub enum VertexAiClientError { + #[error( + "Google Cloud project is required. 
Set it via `ClientBuilder::with_project()` or `GOOGLE_CLOUD_PROJECT`" )] MissingProject, #[error("failed to build source credentials: {0}")] SourceCredentials(String), #[error("failed to build impersonated credentials: {0}")] ImpersonatedCredentials(String), #[error("failed to build Vertex AI prediction service: {0}")] PredictionService(String), #[error( "Vertex AI uses Application Default Credentials (ADC). Use `Client::from_env()` for default credentials or `Client::builder().with_credentials(...).build()` for explicit credentials." )] InvalidInput, } /// Helper function to build credentials with optional service account impersonation. -fn build_credentials(explicit_creds: Option<Credentials>) -> Result<Credentials, String> { +fn build_credentials( + explicit_creds: Option<Credentials>, +) -> Result<Credentials, VertexAiClientError> { if let Some(creds) = explicit_creds { Ok(creds) } else { // Build default credentials let source_credentials = credentials::Builder::default() .build() - .map_err(|e| format!("Failed to build source credentials: {e}"))?; + .map_err(|e| VertexAiClientError::SourceCredentials(e.to_string()))?; // Check for service account impersonation if let Ok(service_account) = std::env::var("GOOGLE_CLOUD_SERVICE_ACCOUNT") { credentials::impersonated::Builder::from_source_credentials(source_credentials) .with_target_principal(service_account) .build() - .map_err(|e| format!("Failed to build impersonated credentials: {e}")) + .map_err(|e| VertexAiClientError::ImpersonatedCredentials(e.to_string())) } else { Ok(source_credentials) } @@ -85,13 +106,11 @@ impl ClientBuilder { /// Build the client with the configured values, falling back to environment variables where not set. /// /// The Vertex AI client is built lazily on first use via `get_inner()`. - pub fn build(self) -> Result<Client, String> { + pub fn build(self) -> Result<Client, VertexAiClientError> { let project = self .project .or_else(|| std::env::var("GOOGLE_CLOUD_PROJECT").ok()) - .ok_or_else(|| { - "Google Cloud project is required. Set it via ClientBuilder::with_project() or GOOGLE_CLOUD_PROJECT environment variable".to_string() - })?; + .ok_or(VertexAiClientError::MissingProject)?; let location = self .location @@ -120,7 +139,7 @@ pub struct Client { project: String, location: String, credentials: Credentials, - pub(crate) vertex_client: Arc<OnceCell<vertexai::client::PredictionService>>, + pub(crate) vertex_client: + Arc<OnceCell<Result<vertexai::client::PredictionService, VertexAiClientError>>>, } impl Client { @@ -134,7 +154,7 @@ impl Client { /// Example: /// ```no_run /// # use rig_vertexai::Client; - /// # fn example() -> Result<(), String> { + /// # fn example() -> Result<(), rig_vertexai::client::VertexAiClientError> { /// // Use all env vars /// let client = Client::builder().build()?; /// @@ -160,11 +180,8 @@ impl Client { /// - `GOOGLE_CLOUD_LOCATION` (optional, defaults to "global") /// - `GOOGLE_CLOUD_SERVICE_ACCOUNT` (optional, for service account impersonation) /// - /// Panics if the environment is improperly configured. For error handling, use `Client::builder().build()`. - pub fn new() -> Self { - ClientBuilder::new() - .build() - .expect("Failed to build Vertex AI client. Make sure GOOGLE_CLOUD_PROJECT is set and credentials are configured (e.g., via 'gcloud auth application-default login')") + pub fn new() -> Result<Self, VertexAiClientError> { + ClientBuilder::new().build() } /// Create a client using environment variables for project, location, and credentials.
@@ -174,7 +191,7 @@ impl Client { /// - `GOOGLE_CLOUD_PROJECT` (required) /// - `GOOGLE_CLOUD_LOCATION` (optional, defaults to "global") /// - `GOOGLE_CLOUD_SERVICE_ACCOUNT` (optional, for service account impersonation) - pub fn from_env() -> Self { + pub fn from_env() -> Result { ::from_env() } @@ -186,7 +203,9 @@ impl Client { &self.location } - pub async fn get_inner(&self) -> &vertexai::client::PredictionService { + pub async fn get_inner( + &self, + ) -> Result<&vertexai::client::PredictionService, VertexAiClientError> { let credentials = self.credentials.clone(); self.vertex_client .get_or_init(|| async { @@ -195,35 +214,30 @@ impl Client { builder .build() .await - .expect("Failed to build Vertex AI client. Make sure you have Google Cloud credentials configured (e.g., via 'gcloud auth application-default login')") + .map_err(|error| VertexAiClientError::PredictionService(error.to_string())) }) .await - } -} - -impl Default for Client { - fn default() -> Self { - Client::new() + .as_ref() + .map_err(Clone::clone) } } impl ProviderClient for Client { type Input = Nothing; + type Error = VertexAiClientError; - fn from_env() -> Self + fn from_env() -> Result where Self: Sized, { Client::new() } - fn from_val(_: Self::Input) -> Self + fn from_val(_: Self::Input) -> Result where Self: Sized, { - panic!( - "Vertex AI uses Application Default Credentials (ADC). Use `Client::from_env()` for default credentials, or `Client::new().with_credentials(...).build()` for custom credentials." - ); + Err(VertexAiClientError::InvalidInput) } } diff --git a/rig-integrations/rig-vertexai/src/completion.rs b/rig-integrations/rig-vertexai/src/completion.rs index 88c5cb93b..5a26a5d33 100644 --- a/rig-integrations/rig-vertexai/src/completion.rs +++ b/rig-integrations/rig-vertexai/src/completion.rs @@ -100,6 +100,7 @@ impl CompletionModelTrait for CompletionModel { .client .get_inner() .await + .map_err(|error| CompletionError::ProviderError(error.to_string()))? .generate_content() .set_model(&model_path) .set_contents(contents); diff --git a/rig-integrations/rig-vertexai/src/lib.rs b/rig-integrations/rig-vertexai/src/lib.rs index 62a267973..e32748ff0 100644 --- a/rig-integrations/rig-vertexai/src/lib.rs +++ b/rig-integrations/rig-vertexai/src/lib.rs @@ -1,3 +1,14 @@ +#![cfg_attr( + test, + allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable + ) +)] + pub mod client; pub mod completion; pub(crate) mod types; diff --git a/rig/rig-core/examples/agent.rs b/rig/rig-core/examples/agent.rs index 18c46170f..1cacb6987 100644 --- a/rig/rig-core/examples/agent.rs +++ b/rig/rig-core/examples/agent.rs @@ -12,7 +12,7 @@ const PROMPT: &str = "Entertain me!"; #[tokio::main] async fn main() -> Result<()> { - let agent = openai::Client::from_env() + let agent = openai::Client::from_env()? 
.agent(openai::GPT_4O) .preamble(PREAMBLE) .build(); diff --git a/rig/rig-core/examples/agent_autonomous.rs b/rig/rig-core/examples/agent_autonomous.rs index ef933f0b7..a0b918097 100644 --- a/rig/rig-core/examples/agent_autonomous.rs +++ b/rig/rig-core/examples/agent_autonomous.rs @@ -34,7 +34,7 @@ fn build_counter_extractor( #[tokio::main] async fn main() -> Result<()> { - let client = Client::from_env(); + let client = Client::from_env()?; let extractor = build_counter_extractor(&client); let mut current_number = 0; let mut step = 1; diff --git a/rig/rig-core/examples/agent_evaluator_optimizer.rs b/rig/rig-core/examples/agent_evaluator_optimizer.rs index 26b345f7d..f047c4bc3 100644 --- a/rig/rig-core/examples/agent_evaluator_optimizer.rs +++ b/rig/rig-core/examples/agent_evaluator_optimizer.rs @@ -27,7 +27,7 @@ All operations should be O(1). #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Create OpenAI client - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; let generator_agent = openai_client .agent(openai::GPT_4) @@ -65,19 +65,18 @@ async fn main() -> Result<(), anyhow::Error> { .build(); let mut memories: Vec = Vec::new(); - let mut response = generator_agent.prompt(TASK).await.unwrap(); + let mut response = generator_agent.prompt(TASK).await?; memories.push(response.clone()); loop { let eval_result = evaluator_agent .extract(&format!("{TASK}\n\n{response}")) - .await - .unwrap(); + .await?; if eval_result.evaluation_status == EvalStatus::Pass { break; } else { let context = format!("{TASK}\n\n{}", eval_result.feedback); - response = generator_agent.prompt(context).await.unwrap(); + response = generator_agent.prompt(context).await?; memories.push(response.clone()); } } diff --git a/rig/rig-core/examples/agent_orchestrator.rs b/rig/rig-core/examples/agent_orchestrator.rs index 084955330..b95283b1b 100644 --- a/rig/rig-core/examples/agent_orchestrator.rs +++ b/rig/rig-core/examples/agent_orchestrator.rs @@ -24,7 +24,7 @@ struct TaskResults { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Create OpenAI client - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; // Note that you can also create your own semantic router for this // that uses a vector store under the hood @@ -49,7 +49,7 @@ async fn main() -> Result<(), anyhow::Error> { let specification = classify_agent.extract(" Write a product description for a new eco-friendly water bottle. 
The target_audience is environmentally conscious millennials and key product features are: plastic-free, insulated, lifetime warranty - ").await.unwrap(); + ").await?; let content_agent = openai_client .extractor::(openai::GPT_4) @@ -73,8 +73,7 @@ async fn main() -> Result<(), anyhow::Error> { ", task.original_task, task.style, task.guidelines )) - .await - .unwrap(); + .await?; vec.push(results); } @@ -89,8 +88,8 @@ async fn main() -> Result<(), anyhow::Error> { ) .build(); - let task_results_raw_json = serde_json::to_string_pretty(&vec).unwrap(); - let results = judge_agent.extract(&task_results_raw_json).await.unwrap(); + let task_results_raw_json = serde_json::to_string_pretty(&vec)?; + let results = judge_agent.extract(&task_results_raw_json).await?; println!("Results: {results:?}"); diff --git a/rig/rig-core/examples/agent_parallelization.rs b/rig/rig-core/examples/agent_parallelization.rs index 67aad7db4..4d460736c 100644 --- a/rig/rig-core/examples/agent_parallelization.rs +++ b/rig/rig-core/examples/agent_parallelization.rs @@ -12,7 +12,7 @@ use rig::{ use schemars::JsonSchema; -#[derive(serde::Deserialize, JsonSchema, serde::Serialize)] +#[derive(Debug, serde::Deserialize, JsonSchema, serde::Serialize)] struct DocumentScore { /// The score of the document score: f32, @@ -20,7 +20,7 @@ struct DocumentScore { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Create OpenAI client - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; let manipulation_agent = openai_client .extractor::(openai::GPT_4) @@ -57,17 +57,25 @@ async fn main() -> Result<(), anyhow::Error> { extract(intelligent_agent) )) .map(|(statement, manip_score, dep_score, int_score)| { - format!( - " - Original statement: {statement} - Manipulation sentiment score: {} - Depression sentiment score: {} - Intelligence sentiment score: {} - ", - manip_score.unwrap().score, - dep_score.unwrap().score, - int_score.unwrap().score - ) + match (manip_score, dep_score, int_score) { + (Ok(manip_score), Ok(dep_score), Ok(int_score)) => format!( + " + Original statement: {statement} + Manipulation sentiment score: {} + Depression sentiment score: {} + Intelligence sentiment score: {} + ", + manip_score.score, dep_score.score, int_score.score + ), + (manip_score, dep_score, int_score) => format!( + " + Original statement: {statement} + Manipulation sentiment score: {manip_score:?} + Depression sentiment score: {dep_score:?} + Intelligence sentiment score: {int_score:?} + " + ), + } }); // Prompt the agent and print the response diff --git a/rig/rig-core/examples/agent_prompt_chaining.rs b/rig/rig-core/examples/agent_prompt_chaining.rs index b47e553d4..f31602eae 100644 --- a/rig/rig-core/examples/agent_prompt_chaining.rs +++ b/rig/rig-core/examples/agent_prompt_chaining.rs @@ -28,7 +28,7 @@ fn build_adder_agent( #[tokio::main] async fn main() -> Result<()> { - let client = Client::from_env(); + let client = Client::from_env()?; let seed = build_rng_agent(&client).prompt(INPUT_PROMPT).await?; let response = build_adder_agent(&client).prompt(seed.trim()).await?; diff --git a/rig/rig-core/examples/agent_routing.rs b/rig/rig-core/examples/agent_routing.rs index 6cc770514..b8266dfa7 100644 --- a/rig/rig-core/examples/agent_routing.rs +++ b/rig/rig-core/examples/agent_routing.rs @@ -40,7 +40,7 @@ fn follow_up_prompt(category: &str) -> Result<&'static str> { #[tokio::main] async fn main() -> Result<()> { - let client = Client::from_env(); + let client = Client::from_env()?; let category = 
build_router_agent(&client).prompt(INPUT_PROMPT).await?; let follow_up = follow_up_prompt(category.trim())?; let response = build_response_agent(&client).prompt(follow_up).await?; diff --git a/rig/rig-core/examples/agent_stream_chat.rs b/rig/rig-core/examples/agent_stream_chat.rs index 5b600d622..f7708d6a2 100644 --- a/rig/rig-core/examples/agent_stream_chat.rs +++ b/rig/rig-core/examples/agent_stream_chat.rs @@ -34,7 +34,7 @@ fn sample_history() -> Vec { #[tokio::main] async fn main() -> Result<()> { - let agent = openai::Client::from_env() + let agent = openai::Client::from_env()? .agent(openai::GPT_4) .preamble(PREAMBLE) .build(); diff --git a/rig/rig-core/examples/agent_with_agent_tool.rs b/rig/rig-core/examples/agent_with_agent_tool.rs index da08ddc68..5fd61cb44 100644 --- a/rig/rig-core/examples/agent_with_agent_tool.rs +++ b/rig/rig-core/examples/agent_with_agent_tool.rs @@ -64,10 +64,10 @@ impl Tool for Subtract { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "subtract", - "description": "Subtract y from x (i.e.: x - y)", - "parameters": { + ToolDefinition { + name: "subtract".to_string(), + description: "Subtract y from x (i.e.: x - y)".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -80,9 +80,8 @@ impl Tool for Subtract { } }, "required": ["x", "y"], - }, - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -100,7 +99,7 @@ async fn main() -> Result<(), anyhow::Error> { .init(); // Create OpenAI client - let openai_client = providers::openai::Client::from_env(); + let openai_client = providers::openai::Client::from_env()?; // Create agent with a single context prompt and two tools let calculator_agent = openai_client diff --git a/rig/rig-core/examples/agent_with_context.rs b/rig/rig-core/examples/agent_with_context.rs index 5a420cd17..be06166dd 100644 --- a/rig/rig-core/examples/agent_with_context.rs +++ b/rig/rig-core/examples/agent_with_context.rs @@ -18,7 +18,7 @@ const CONTEXT_PROMPT: &str = "What does \"glarb-glarb\" mean?"; #[tokio::main] async fn main() -> Result<()> { - let client = cohere::Client::from_env(); + let client = cohere::Client::from_env()?; let model = client.completion_model(COMMAND_R); let agent = CONTEXT_DOCS .iter() diff --git a/rig/rig-core/examples/agent_with_default_max_turns.rs b/rig/rig-core/examples/agent_with_default_max_turns.rs index c2adbc0f4..5763fbd8d 100644 --- a/rig/rig-core/examples/agent_with_default_max_turns.rs +++ b/rig/rig-core/examples/agent_with_default_max_turns.rs @@ -82,7 +82,7 @@ const PROMPT: &str = "Calculate (3 + 5) / 4 and describe the result."; #[tokio::main] async fn main() -> Result<()> { - let agent = anthropic::Client::from_env() + let agent = anthropic::Client::from_env()? .agent(anthropic::completion::CLAUDE_SONNET_4_6) .preamble( "You are an assistant that must use the available tools for arithmetic. 
\ diff --git a/rig/rig-core/examples/agent_with_echochambers.rs b/rig/rig-core/examples/agent_with_echochambers.rs index c05ae5553..fff308d89 100644 --- a/rig/rig-core/examples/agent_with_echochambers.rs +++ b/rig/rig-core/examples/agent_with_echochambers.rs @@ -303,11 +303,10 @@ impl Tool for GetMetricsHistory { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Get API keys from environment - let echochambers_api_key = - env::var("ECHOCHAMBERS_API_KEY").expect("ECHOCHAMBERS_API_KEY not set"); + let echochambers_api_key = env::var("ECHOCHAMBERS_API_KEY")?; // Create OpenAI client - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; // Create agent with all tools let echochambers_agent = openai_client diff --git a/rig/rig-core/examples/agent_with_loaders.rs b/rig/rig-core/examples/agent_with_loaders.rs index 3266baee3..1e61063d2 100644 --- a/rig/rig-core/examples/agent_with_loaders.rs +++ b/rig/rig-core/examples/agent_with_loaders.rs @@ -21,7 +21,7 @@ fn load_example_contexts() -> Result Result<()> { - let client = openai::Client::from_env(); + let client = openai::Client::from_env()?; let model = client.completion_model(openai::GPT_4O); let files = load_example_contexts()?; diff --git a/rig/rig-core/examples/agent_with_tools.rs b/rig/rig-core/examples/agent_with_tools.rs index f073b1ad2..76fe14db1 100644 --- a/rig/rig-core/examples/agent_with_tools.rs +++ b/rig/rig-core/examples/agent_with_tools.rs @@ -85,7 +85,7 @@ fn boxed_tools() -> Vec> { #[tokio::main] async fn main() -> Result<()> { - let agent = openai::Client::from_env() + let agent = openai::Client::from_env()? .agent(openai::GPT_4O) .preamble( "You are a calculator here to help the user perform arithmetic operations. \ diff --git a/rig/rig-core/examples/agent_with_tools_otel.rs b/rig/rig-core/examples/agent_with_tools_otel.rs index 3394d643b..22f75ce8f 100644 --- a/rig/rig-core/examples/agent_with_tools_otel.rs +++ b/rig/rig-core/examples/agent_with_tools_otel.rs @@ -79,10 +79,10 @@ impl Tool for Subtract { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "subtract", - "description": "Subtract y from x (i.e.: x - y)", - "parameters": { + ToolDefinition { + name: "subtract".to_string(), + description: "Subtract y from x (i.e.: x - y)".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -95,9 +95,8 @@ impl Tool for Subtract { } }, "required": ["x", "y"], - }, - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -141,7 +140,7 @@ async fn main() -> Result<(), anyhow::Error> { .init(); // Create OpenAI client - let openai_client = providers::openai::Client::from_env(); + let openai_client = providers::openai::Client::from_env()?; // Create agent with a single context prompt and two tools let calculator_agent = openai_client diff --git a/rig/rig-core/examples/calculator_chatbot.rs b/rig/rig-core/examples/calculator_chatbot.rs index ded4c1b1d..52476714c 100644 --- a/rig/rig-core/examples/calculator_chatbot.rs +++ b/rig/rig-core/examples/calculator_chatbot.rs @@ -37,10 +37,10 @@ impl Tool for Add { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "add", - "description": "Add x and y together", - "parameters": { + ToolDefinition { + name: "add".to_string(), + description: "Add x and y together".to_string(), + parameters: json!({ "type": "object", "properties": { "x": 
{ @@ -53,9 +53,8 @@ impl Tool for Add { } }, "required": [ "x", "y" ] - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -90,10 +89,10 @@ impl Tool for Subtract { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "subtract", - "description": "Subtract y from x (i.e.: x - y)", - "parameters": { + ToolDefinition { + name: "subtract".to_string(), + description: "Subtract y from x (i.e.: x - y)".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -106,9 +105,8 @@ impl Tool for Subtract { } }, "required": [ "x", "y" ] - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -142,10 +140,10 @@ impl Tool for Multiply { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "multiply", - "description": "Compute the product of x and y (i.e.: x * y)", - "parameters": { + ToolDefinition { + name: "multiply".to_string(), + description: "Compute the product of x and y (i.e.: x * y)".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -158,9 +156,8 @@ impl Tool for Multiply { } }, "required": [ "x", "y" ] - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -190,10 +187,11 @@ impl Tool for Divide { type Args = OperationArgs; type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "divide", - "description": "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios.", - "parameters": { + ToolDefinition { + name: "divide".to_string(), + description: "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios." 
+ .to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -206,9 +204,8 @@ impl Tool for Divide { } }, "required": [ "x", "y" ] - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { let result = args.x / args.y; @@ -235,7 +232,7 @@ impl ToolEmbedding for Divide { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Create OpenAI client - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; // Create dynamic tools embeddings let toolset = ToolSet::builder() diff --git a/rig/rig-core/examples/chain.rs b/rig/rig-core/examples/chain.rs index 2e4af7fd5..2f749a83b 100644 --- a/rig/rig-core/examples/chain.rs +++ b/rig/rig-core/examples/chain.rs @@ -49,7 +49,7 @@ fn lookup_context(docs: Vec<(f64, String, String)>, prompt: &str) -> String { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { tracing_subscriber::fmt().init(); - let client = Client::from_env(); + let client = Client::from_env()?; let embedding_model = client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); let mut builder = EmbeddingsBuilder::new(embedding_model.clone()); diff --git a/rig/rig-core/examples/complex_agentic_loop_claude.rs b/rig/rig-core/examples/complex_agentic_loop_claude.rs index 62a92fedb..c4820ce9d 100644 --- a/rig/rig-core/examples/complex_agentic_loop_claude.rs +++ b/rig/rig-core/examples/complex_agentic_loop_claude.rs @@ -26,12 +26,12 @@ async fn main() -> Result<(), anyhow::Error> { .init(); // Create Anthropic client - let anthropic_api_key = env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"); + let anthropic_api_key = env::var("ANTHROPIC_API_KEY")?; let anthropic_client = Client::builder().api_key(&anthropic_api_key).build()?; // Create the embedding model for our vector store // We'll use OpenAI's embedding model for this example - let openai_client = rig::providers::openai::Client::from_env(); + let openai_client = rig::providers::openai::Client::from_env()?; let embedding_model = openai_client.embedding_model(rig::providers::openai::TEXT_EMBEDDING_ADA_002); @@ -183,14 +183,12 @@ async fn main() -> Result<(), anyhow::Error> { Message::User { content } => println!( "\nUser [{}]: {}", i, - serde_json::to_string_pretty(&content) - .expect("Failed to serialize user message") + serde_json::to_string_pretty(&content)? ), Message::Assistant { content, .. } => println!( "Assistant [{}]: {}", i, - serde_json::to_string_pretty(&content) - .expect("Failed to serialize assistant message") + serde_json::to_string_pretty(&content)? ), _ => { // Ignore other message types - the only other type of message that exists is system messages diff --git a/rig/rig-core/examples/custom_vector_store.rs b/rig/rig-core/examples/custom_vector_store.rs index c4ab49fe9..998d20657 100644 --- a/rig/rig-core/examples/custom_vector_store.rs +++ b/rig/rig-core/examples/custom_vector_store.rs @@ -130,7 +130,8 @@ impl VectorStoreIndex for RedisVectorStore { .map(serde_json::from_str) .transpose() .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))? 
- .unwrap_or_else(|| serde_json::from_str("{}").unwrap()); + .map_or_else(|| serde_json::from_str("{}"), Ok) + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; output.push((score, id, metadata)); } @@ -180,7 +181,7 @@ struct Document { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client from environment - let openai_client = openai::Client::from_env(); + let openai_client = openai::Client::from_env()?; // Convert it to an EmbeddingModel let embedding_model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); diff --git a/rig/rig-core/examples/debate.rs b/rig/rig-core/examples/debate.rs index 35ad8083e..7f0bde57e 100644 --- a/rig/rig-core/examples/debate.rs +++ b/rig/rig-core/examples/debate.rs @@ -13,15 +13,15 @@ struct Debater { } impl Debater { - fn new(position_a: &str, position_b: &str) -> Self { + fn new(position_a: &str, position_b: &str) -> Result { tracing_subscriber::fmt() .with_max_level(tracing::Level::INFO) .with_target(false) .init(); - let openai_client = openai::Client::from_env(); - let cohere_client = cohere::Client::from_env(); + let openai_client = openai::Client::from_env()?; + let cohere_client = cohere::Client::from_env()?; - Self { + Ok(Self { gpt_4: openai_client .agent(openai::GPT_4) .preamble(position_a) @@ -30,7 +30,7 @@ impl Debater { .agent(cohere::COMMAND_R) .preamble(position_b) .build(), - } + }) } async fn rounds(&self, n: usize) -> Result<()> { @@ -88,7 +88,7 @@ async fn main() -> Result<(), anyhow::Error> { You choose what your arguments are. \ I will argue against you and you must rebuke me and try to convince me that I am wrong. \ Make your statements short and concise.", - ); + )?; // Run the debate for 4 rounds debator.rounds(4).await?; diff --git a/rig/rig-core/examples/discord_bot.rs b/rig/rig-core/examples/discord_bot.rs index 4e8054685..c9a28c263 100644 --- a/rig/rig-core/examples/discord_bot.rs +++ b/rig/rig-core/examples/discord_bot.rs @@ -4,10 +4,9 @@ use rig::providers::openai; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { - let discord_bot_token = std::env::var("DISCORD_BOT_TOKEN") - .expect("DISCORD_BOT_TOKEN to be set as an environment variable"); + let discord_bot_token = std::env::var("DISCORD_BOT_TOKEN")?; // Create OpenAI client - let client = rig::providers::openai::Client::from_env(); + let client = rig::providers::openai::Client::from_env()?; // Create agent with a single context prompt let mut discord_bot = client @@ -15,9 +14,9 @@ async fn main() -> Result<(), anyhow::Error> { .preamble("You are a helpful assistant.") .build() .into_discord_bot(&discord_bot_token) - .await; + .await?; - discord_bot.start().await.unwrap(); + discord_bot.start().await?; Ok(()) } diff --git a/rig/rig-core/examples/enum_dispatch.rs b/rig/rig-core/examples/enum_dispatch.rs index 8df3b9989..6d69f0f92 100644 --- a/rig/rig-core/examples/enum_dispatch.rs +++ b/rig/rig-core/examples/enum_dispatch.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use anyhow::{Result, anyhow}; use rig::agent::Agent; use rig::client::{CompletionClient, ProviderClient}; use rig::completion::{Prompt, PromptError}; @@ -28,77 +29,83 @@ struct AgentConfig<'a> { // In production you would likely want to create some sort of `RegistryKey` type instead of // allowing arbitrary strings, for improved type safety -struct ProviderRegistry(HashMap<&'static str, fn(AgentConfig) -> Agents>); +struct ProviderRegistry(HashMap<&'static str, fn(AgentConfig<'_>) -> Result>); -fn anthropic_agent(AgentConfig { name, 
preamble }: AgentConfig) -> Agents { - let agent = anthropic::Client::from_env() +fn anthropic_agent(AgentConfig { name, preamble }: AgentConfig<'_>) -> Result { + let agent = anthropic::Client::from_env()? .agent(CLAUDE_SONNET_4_6) .name(name) .preamble(preamble) .build(); - Agents::Anthropic(agent) + Ok(Agents::Anthropic(agent)) } -fn openai_agent(AgentConfig { name, preamble }: AgentConfig) -> Agents { - let agent = openai::Client::from_env() +fn openai_agent(AgentConfig { name, preamble }: AgentConfig<'_>) -> Result { + let agent = openai::Client::from_env()? .completions_api() .agent(GPT_4O) .name(name) .preamble(preamble) .build(); - Agents::OpenAI(agent) + Ok(Agents::OpenAI(agent)) } impl ProviderRegistry { pub fn new() -> Self { Self(HashMap::from_iter([ - ("anthropic", anthropic_agent as fn(AgentConfig) -> Agents), - ("openai", openai_agent as fn(AgentConfig) -> Agents), + ( + "anthropic", + anthropic_agent as fn(AgentConfig<'_>) -> Result, + ), + ( + "openai", + openai_agent as fn(AgentConfig<'_>) -> Result, + ), ])) } - pub fn agent(&self, provider: &str, agent_config: AgentConfig) -> Option { - self.0.get(provider).map(|p| p(agent_config)) + pub fn agent(&self, provider: &str, agent_config: AgentConfig<'_>) -> Result { + let builder = self + .0 + .get(provider) + .ok_or_else(|| anyhow!("unknown provider: {provider}"))?; + builder(agent_config) } } #[tokio::main] -async fn main() { +async fn main() -> Result<()> { let registry = ProviderRegistry::new(); - let openai_agent = registry - .agent( - "openai", - AgentConfig { - name: "Assistant", - preamble: "You are a helpful assistant", - }, - ) - .unwrap(); - - let anthropic_agent = registry - .agent( - "anthropic", - AgentConfig { - name: "Assistant", - preamble: "You are an unhelpful assistant", - }, - ) - .unwrap(); + let openai_agent = registry.agent( + "openai", + AgentConfig { + name: "Assistant", + preamble: "You are a helpful assistant", + }, + )?; + + let anthropic_agent = registry.agent( + "anthropic", + AgentConfig { + name: "Assistant", + preamble: "You are an unhelpful assistant", + }, + )?; let oai_response = openai_agent .prompt("How much does 4oz of parmesan cheese weigh") - .await - .unwrap(); + .await?; println!("Helpful: {oai_response}"); let anthropic_response = anthropic_agent .prompt("How much does 4oz of parmesan cheese weigh") - .await - .unwrap(); + .await?; println!("Unhelpful: {anthropic_response}"); + + Ok(()) } diff --git a/rig/rig-core/examples/extractor.rs b/rig/rig-core/examples/extractor.rs index 6b05c5989..6b04d2cae 100644 --- a/rig/rig-core/examples/extractor.rs +++ b/rig/rig-core/examples/extractor.rs @@ -23,7 +23,7 @@ const SECOND_INPUT: &str = "Jane Smith is a data scientist."; #[tokio::main] async fn main() -> Result<()> { - let client = openai::Client::from_env(); + let client = openai::Client::from_env()?; let extractor = client.extractor::(openai::GPT_4).build(); let person = extractor.extract(FIRST_INPUT).await?; diff --git a/rig/rig-core/examples/gemini_deep_research.rs b/rig/rig-core/examples/gemini_deep_research.rs index f5c5bb51b..ca15c6710 100644 --- a/rig/rig-core/examples/gemini_deep_research.rs +++ b/rig/rig-core/examples/gemini_deep_research.rs @@ -132,7 +132,7 @@ async fn main() -> Result<()> { .init(); let use_streaming = std::env::args().any(|arg| arg == "--stream"); - let client = gemini::Client::from_env().interactions_api(); + let client = gemini::Client::from_env()?.interactions_api(); let model = client.completion_model("gemini-3-pro-preview"); let prompt = "Research the 
history of Google TPUs."; diff --git a/rig/rig-core/examples/gemini_extractor_with_rag.rs b/rig/rig-core/examples/gemini_extractor_with_rag.rs index e6e83ced1..796432bfd 100644 --- a/rig/rig-core/examples/gemini_extractor_with_rag.rs +++ b/rig/rig-core/examples/gemini_extractor_with_rag.rs @@ -63,7 +63,7 @@ async fn main() -> Result<(), anyhow::Error> { .init(); // Create Gemini client - let gemini_client = Client::from_env(); + let gemini_client = Client::from_env()?; let embedding_model = gemini_client.embedding_model(gemini::EMBEDDING_001); // Generate embeddings for the definitions of all the documents using the specified embedding model. diff --git a/rig/rig-core/examples/gemini_video_understanding.rs b/rig/rig-core/examples/gemini_video_understanding.rs index 65d9085b7..14e7694dc 100644 --- a/rig/rig-core/examples/gemini_video_understanding.rs +++ b/rig/rig-core/examples/gemini_video_understanding.rs @@ -47,7 +47,7 @@ fn build_additional_params() -> Result { #[tokio::main] async fn main() -> Result<()> { - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env()?; let additional_params = build_additional_params()?; let agent = client .agent(MODEL) diff --git a/rig/rig-core/examples/manual_tool_calls.rs b/rig/rig-core/examples/manual_tool_calls.rs index 7dc09fbe3..cb323204f 100644 --- a/rig/rig-core/examples/manual_tool_calls.rs +++ b/rig/rig-core/examples/manual_tool_calls.rs @@ -116,7 +116,7 @@ fn tool_result_message(tool_call: &ToolCall, output: String) -> Message { async fn main() -> Result<()> { const MAX_ROUNDS: usize = 8; - let agent = openai::Client::from_env() + let agent = openai::Client::from_env()? .agent(openai::GPT_4O_MINI) .preamble( "You are a calculator. Never do arithmetic from memory. \ diff --git a/rig/rig-core/examples/multi_agent.rs b/rig/rig-core/examples/multi_agent.rs index a71fb3d83..7cdab14b4 100644 --- a/rig/rig-core/examples/multi_agent.rs +++ b/rig/rig-core/examples/multi_agent.rs @@ -29,21 +29,22 @@ impl Tool for TranslatorTool { type Output = String; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": Self::NAME, - "description": "Translate any text to English. If already in English, fix grammar and syntax issues.", - "parameters": { + ToolDefinition { + name: Self::NAME.to_string(), + description: + "Translate any text to English. If already in English, fix grammar and syntax issues." 
+ .to_string(), + parameters: json!({ "type": "object", "properties": { "prompt": { "type": "string", "description": "The text to translate to English" - }, + } }, "required": ["prompt"] - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -66,7 +67,7 @@ impl Tool for TranslatorTool { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Create OpenAI client - let openai_client = OpenAIClient::from_env(); + let openai_client = OpenAIClient::from_env()?; let model = openai_client.completion_model(openai::GPT_4O); let translator_agent = AgentBuilder::new(model.clone()) diff --git a/rig/rig-core/examples/multi_extract.rs b/rig/rig-core/examples/multi_extract.rs index 6347a0789..0c4887ba7 100644 --- a/rig/rig-core/examples/multi_extract.rs +++ b/rig/rig-core/examples/multi_extract.rs @@ -36,7 +36,7 @@ fn sample_inputs() -> Vec<&'static str> { #[tokio::main] async fn main() -> Result<()> { - let client = openai::Client::from_env(); + let client = openai::Client::from_env()?; let names_extractor = client .extractor::(openai::GPT_4O_MINI) .preamble("Extract names from the given text.") diff --git a/rig/rig-core/examples/multi_turn_agent.rs b/rig/rig-core/examples/multi_turn_agent.rs index 4f546e835..2b6c69442 100644 --- a/rig/rig-core/examples/multi_turn_agent.rs +++ b/rig/rig-core/examples/multi_turn_agent.rs @@ -15,7 +15,7 @@ async fn main() -> anyhow::Result<()> { .init(); // Create OpenAI client - let openai_client = anthropic::Client::from_env(); + let openai_client = anthropic::Client::from_env()?; // Create RAG agent with a single context prompt and a dynamic tool source let agent = openai_client @@ -73,10 +73,10 @@ impl Tool for Add { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "add", - "description": "Add x and y together", - "parameters": { + ToolDefinition { + name: "add".to_string(), + description: "Add x and y together".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -88,9 +88,8 @@ impl Tool for Add { "description": "The second number to add" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -109,10 +108,10 @@ impl Tool for Subtract { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "subtract", - "description": "Subtract y from x (i.e.: x - y)", - "parameters": { + ToolDefinition { + name: "subtract".to_string(), + description: "Subtract y from x (i.e.: x - y)".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -124,9 +123,8 @@ impl Tool for Subtract { "description": "The number to subtract" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -145,10 +143,10 @@ impl Tool for Multiply { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "multiply", - "description": "Compute the product of x and y (i.e.: x * y)", - "parameters": { + ToolDefinition { + name: "multiply".to_string(), + description: "Compute the product of x and y (i.e.: x * y)".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -160,9 +158,8 @@ impl Tool for Multiply { "description": "The second factor in the product" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ 
-181,10 +178,11 @@ impl Tool for Divide { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "divide", - "description": "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios.", - "parameters": { + ToolDefinition { + name: "divide".to_string(), + description: "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios." + .to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -196,9 +194,8 @@ impl Tool for Divide { "description": "The Divisor of the division. The number by which the dividend is being divided" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { diff --git a/rig/rig-core/examples/multi_turn_agent_extended.rs b/rig/rig-core/examples/multi_turn_agent_extended.rs index a88ac30a6..2c4a761d7 100644 --- a/rig/rig-core/examples/multi_turn_agent_extended.rs +++ b/rig/rig-core/examples/multi_turn_agent_extended.rs @@ -15,7 +15,7 @@ async fn main() -> anyhow::Result<()> { .init(); // Create OpenAI client - let openai_client = anthropic::Client::from_env(); + let openai_client = anthropic::Client::from_env()?; // Create RAG agent with a single context prompt and a dynamic tool source let agent = openai_client @@ -75,10 +75,10 @@ impl Tool for Add { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "add", - "description": "Add x and y together", - "parameters": { + ToolDefinition { + name: "add".to_string(), + description: "Add x and y together".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -90,9 +90,8 @@ impl Tool for Add { "description": "The second number to add" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -111,10 +110,10 @@ impl Tool for Subtract { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "subtract", - "description": "Subtract y from x (i.e.: x - y)", - "parameters": { + ToolDefinition { + name: "subtract".to_string(), + description: "Subtract y from x (i.e.: x - y)".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -126,9 +125,8 @@ impl Tool for Subtract { "description": "The number to subtract" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -147,10 +145,10 @@ impl Tool for Multiply { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "multiply", - "description": "Compute the product of x and y (i.e.: x * y)", - "parameters": { + ToolDefinition { + name: "multiply".to_string(), + description: "Compute the product of x and y (i.e.: x * y)".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -162,9 +160,8 @@ impl Tool for Multiply { "description": "The second factor in the product" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -183,10 +180,11 @@ impl Tool for Divide { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "divide", - "description": "Compute the Quotient of x and y (i.e.: x / y). 
Useful for ratios.", - "parameters": { + ToolDefinition { + name: "divide".to_string(), + description: "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios." + .to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -198,9 +196,8 @@ impl Tool for Divide { "description": "The Divisor of the division. The number by which the dividend is being divided" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { diff --git a/rig/rig-core/examples/openai_agent_completions_api_otel.rs b/rig/rig-core/examples/openai_agent_completions_api_otel.rs index c09c3fbd1..a57a46088 100644 --- a/rig/rig-core/examples/openai_agent_completions_api_otel.rs +++ b/rig/rig-core/examples/openai_agent_completions_api_otel.rs @@ -43,14 +43,14 @@ async fn main() -> Result<(), anyhow::Error> { .init(); // Create OpenAI client - let agent = providers::openai::Client::from_env() + let agent = providers::openai::Client::from_env()? .completion_model(openai::GPT_4O) .completions_api() .into_agent_builder() .preamble("You are a helpful assistant") .build(); - let res = agent.prompt("Hello world!").await.unwrap(); + let res = agent.prompt("Hello world!").await?; println!("GPT-4o: {res}"); diff --git a/rig/rig-core/examples/openai_streaming_with_tools_otel.rs b/rig/rig-core/examples/openai_streaming_with_tools_otel.rs index 0e21f5973..ee2499760 100644 --- a/rig/rig-core/examples/openai_streaming_with_tools_otel.rs +++ b/rig/rig-core/examples/openai_streaming_with_tools_otel.rs @@ -70,10 +70,10 @@ impl Tool for Subtract { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "subtract", - "description": "Subtract y from x (i.e.: x - y)", - "parameters": { + ToolDefinition { + name: "subtract".to_string(), + description: "Subtract y from x (i.e.: x - y)".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -86,9 +86,8 @@ impl Tool for Subtract { } }, "required": ["x", "y"], - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -127,7 +126,7 @@ async fn main() -> Result<(), anyhow::Error> { .init(); // Create agent with a single context prompt and two tools - let calculator_agent = providers::openai::Client::from_env() + let calculator_agent = providers::openai::Client::from_env()? 
.agent(providers::openai::GPT_4O) .preamble( "You are a calculator here to help the user perform arithmetic diff --git a/rig/rig-core/examples/pdf_agent.rs b/rig/rig-core/examples/pdf_agent.rs index 278896d0b..116f10bf7 100644 --- a/rig/rig-core/examples/pdf_agent.rs +++ b/rig/rig-core/examples/pdf_agent.rs @@ -59,8 +59,7 @@ async fn main() -> Result<()> { let client = ollama::Client::builder() .api_key(Nothing) .base_url("http://localhost:11434/v1") - .build() - .unwrap(); + .build()?; // Load PDFs using Rig's built-in PDF loader let documents_dir = std::env::current_dir()?.join("rig-core/examples/documents"); diff --git a/rig/rig-core/examples/rag.rs b/rig/rig-core/examples/rag.rs index 25404eb14..fdf250116 100644 --- a/rig/rig-core/examples/rag.rs +++ b/rig/rig-core/examples/rag.rs @@ -27,7 +27,7 @@ async fn main() -> Result<(), anyhow::Error> { .init(); // Create OpenAI client - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; let embedding_model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); // Generate embeddings for the definitions of all the documents using the specified embedding model. diff --git a/rig/rig-core/examples/rag_dynamic_tools.rs b/rig/rig-core/examples/rag_dynamic_tools.rs index c07cf06d4..0ed4dab88 100644 --- a/rig/rig-core/examples/rag_dynamic_tools.rs +++ b/rig/rig-core/examples/rag_dynamic_tools.rs @@ -35,10 +35,10 @@ impl Tool for Add { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "add", - "description": "Add x and y together", - "parameters": { + ToolDefinition { + name: "add".to_string(), + description: "Add x and y together".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -50,9 +50,8 @@ impl Tool for Add { "description": "The second number to add" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { let result = args.x + args.y; @@ -79,10 +78,10 @@ impl Tool for Subtract { type Args = OperationArgs; type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "subtract", - "description": "Subtract y from x (i.e.: x - y)", - "parameters": { + ToolDefinition { + name: "subtract".to_string(), + description: "Subtract y from x (i.e.: x - y)".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -94,9 +93,8 @@ impl Tool for Subtract { "description": "The number to subtract" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -130,7 +128,7 @@ async fn main() -> Result<(), anyhow::Error> { .init(); // Create OpenAI client - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; let embedding_model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); let toolset = ToolSet::builder() .dynamic_tool(Add) diff --git a/rig/rig-core/examples/rag_dynamic_tools_multi_turn.rs b/rig/rig-core/examples/rag_dynamic_tools_multi_turn.rs index 8cf9d4931..9aacdfaff 100644 --- a/rig/rig-core/examples/rag_dynamic_tools_multi_turn.rs +++ b/rig/rig-core/examples/rag_dynamic_tools_multi_turn.rs @@ -35,10 +35,10 @@ impl Tool for Add { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "add", - "description": "Add x and y together", - "parameters": { + ToolDefinition { + name: "add".to_string(), + description: "Add x and 
y together".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -50,9 +50,8 @@ impl Tool for Add { "description": "The second number to add" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -88,10 +87,10 @@ impl Tool for Subtract { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "subtract", - "description": "Subtract y from x (i.e.: x - y)", - "parameters": { + ToolDefinition { + name: "subtract".to_string(), + description: "Subtract y from x (i.e.: x - y)".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -103,9 +102,8 @@ impl Tool for Subtract { "description": "The number to subtract" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -140,7 +138,7 @@ async fn main() -> Result<(), anyhow::Error> { .init(); // Create OpenAI client - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; let embedding_model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); diff --git a/rig/rig-core/examples/rag_ollama.rs b/rig/rig-core/examples/rag_ollama.rs index 50ce07fb7..e17a62b66 100644 --- a/rig/rig-core/examples/rag_ollama.rs +++ b/rig/rig-core/examples/rag_ollama.rs @@ -26,7 +26,7 @@ async fn main() -> Result<(), anyhow::Error> { .init(); // Create ollama client - let ollama_client = Client::from_val(Nothing.into()); + let ollama_client = Client::from_val(Nothing.into())?; let embedding_model = ollama_client.embedding_model("nomic-embed-text"); // Generate embeddings for the definitions of all the documents using the specified embedding model. diff --git a/rig/rig-core/examples/reasoning_loop.rs b/rig/rig-core/examples/reasoning_loop.rs index a746283f9..aae43bb3d 100644 --- a/rig/rig-core/examples/reasoning_loop.rs +++ b/rig/rig-core/examples/reasoning_loop.rs @@ -57,7 +57,8 @@ impl Prompt for ReasoningAgent { let history_vec: Vec<_> = messages.clone().into_iter().collect(); tracing::info!( "full chat history generated: {}", - serde_json::to_string_pretty(&history_vec).unwrap() + serde_json::to_string_pretty(&history_vec) + .unwrap_or_else(|_| "".to_string()) ); } Ok(response.output) @@ -72,7 +73,7 @@ async fn main() -> anyhow::Result<()> { .init(); // Create Anthropic client - let anthropic_client = anthropic::Client::from_env(); + let anthropic_client = anthropic::Client::from_env()?; let agent = ReasoningAgent { chain_of_thought_extractor: anthropic_client .extractor(anthropic::completion::CLAUDE_SONNET_4_6) @@ -127,10 +128,10 @@ impl Tool for Add { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "add", - "description": "Add x and y together", - "parameters": { + ToolDefinition { + name: "add".to_string(), + description: "Add x and y together".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -142,9 +143,8 @@ impl Tool for Add { "description": "The second number to add" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { let result = args.x + args.y; @@ -159,10 +159,10 @@ impl Tool for Subtract { type Args = OperationArgs; type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "subtract", - "description": "Subtract y from x (i.e.: x - y)", - "parameters": { + ToolDefinition { + 
name: "subtract".to_string(), + description: "Subtract y from x (i.e.: x - y)".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -174,9 +174,8 @@ impl Tool for Subtract { "description": "The number to subtract" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -194,10 +193,10 @@ impl Tool for Multiply { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "multiply", - "description": "Compute the product of x and y (i.e.: x * y)", - "parameters": { + ToolDefinition { + name: "multiply".to_string(), + description: "Compute the product of x and y (i.e.: x * y)".to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -209,9 +208,8 @@ impl Tool for Multiply { "description": "The second factor in the product" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { @@ -229,10 +227,11 @@ impl Tool for Divide { type Output = i32; async fn definition(&self, _prompt: String) -> ToolDefinition { - serde_json::from_value(json!({ - "name": "divide", - "description": "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios.", - "parameters": { + ToolDefinition { + name: "divide".to_string(), + description: "Compute the Quotient of x and y (i.e.: x / y). Useful for ratios." + .to_string(), + parameters: json!({ "type": "object", "properties": { "x": { @@ -244,9 +243,8 @@ impl Tool for Divide { "description": "The Divisor of the division. The number by which the dividend is being divided" } } - } - })) - .expect("Tool Definition") + }), + } } async fn call(&self, args: Self::Args) -> Result { diff --git a/rig/rig-core/examples/request_hook.rs b/rig/rig-core/examples/request_hook.rs index d01921f77..cbd545988 100644 --- a/rig/rig-core/examples/request_hook.rs +++ b/rig/rig-core/examples/request_hook.rs @@ -51,7 +51,7 @@ where #[tokio::main] async fn main() -> Result<()> { - let agent = openai::Client::from_env() + let agent = openai::Client::from_env()? .agent(openai::GPT_4O) .preamble("You are a comedian here to entertain the user using humour and jokes.") .build(); diff --git a/rig/rig-core/examples/rmcp.rs b/rig/rig-core/examples/rmcp.rs index 492a07449..fc0150b45 100644 --- a/rig/rig-core/examples/rmcp.rs +++ b/rig/rig-core/examples/rmcp.rs @@ -262,14 +262,14 @@ async fn main() -> anyhow::Result<()> { let server_info = mcp_service.peer_info(); tracing::info!("Connected to server: {server_info:#?}"); - let openai_client = openai::Client::from_env(); + let openai_client = openai::Client::from_env()?; let agent = openai_client .agent(openai::GPT_4O) .preamble("You are a helpful assistant who has access to a number of tools from an MCP server designed to be used for incrementing and decrementing a counter.") .tool_server_handle(tool_server_handle) .build(); - let res = agent.prompt("What is 2+5?").max_turns(2).await.unwrap(); + let res = agent.prompt("What is 2+5?").max_turns(2).await?; println!("GPT-4o: {res}"); diff --git a/rig/rig-core/examples/sentiment_classifier.rs b/rig/rig-core/examples/sentiment_classifier.rs index 99a632219..777c02691 100644 --- a/rig/rig-core/examples/sentiment_classifier.rs +++ b/rig/rig-core/examples/sentiment_classifier.rs @@ -24,7 +24,7 @@ struct DocumentSentiment { #[tokio::main] async fn main() -> Result<()> { - let extractor = openai::Client::from_env() + let extractor = openai::Client::from_env()? 
.extractor::(openai::GPT_4) .build(); diff --git a/rig/rig-core/examples/transcription.rs b/rig/rig-core/examples/transcription.rs index a6178d0d2..fc0f67d28 100644 --- a/rig/rig-core/examples/transcription.rs +++ b/rig/rig-core/examples/transcription.rs @@ -7,98 +7,97 @@ use rig::{ use std::env::args; #[tokio::main] -async fn main() { +async fn main() -> Result<(), anyhow::Error> { let args = args().collect::>(); if args.len() <= 1 { println!("No file was specified!"); - return; + return Ok(()); } - let file_path = args[1].clone(); + let file_path = args + .get(1) + .cloned() + .ok_or_else(|| anyhow::anyhow!("No file was specified"))?; println!("Transcribing {}", &file_path); - whisper(&file_path).await; - gemini(&file_path).await; - azure(&file_path).await; - groq(&file_path).await; - huggingface(&file_path).await; - mistral(&file_path).await; + whisper(&file_path).await?; + gemini(&file_path).await?; + azure(&file_path).await?; + groq(&file_path).await?; + huggingface(&file_path).await?; + mistral(&file_path).await?; + + Ok(()) } -async fn whisper(file_path: &str) { - let openai = openai::Client::from_env(); +async fn whisper(file_path: &str) -> Result<(), anyhow::Error> { + let openai = openai::Client::from_env()?; let whisper = openai.transcription_model(openai::WHISPER_1); let response = whisper .transcription_request() - .load_file(file_path) - .expect("Failed to load file for transcription") + .load_file(file_path)? .send() - .await - .expect("Failed to transcribe file"); + .await?; println!("Whisper-1: {}", response.text); + Ok(()) } -async fn gemini(file_path: &str) { - let gemini = gemini::Client::from_env(); +async fn gemini(file_path: &str) -> Result<(), anyhow::Error> { + let gemini = gemini::Client::from_env()?; let model = gemini.transcription_model(gemini::completion::GEMINI_3_FLASH_PREVIEW); let response = model .transcription_request() - .load_file(file_path) - .expect("Failed to load file for transcription") + .load_file(file_path)? .send() - .await - .expect("Failed to transcribe file"); + .await?; println!("Gemini: {}", response.text); + Ok(()) } -async fn azure(file_path: &str) { - let azure = azure::Client::from_env(); +async fn azure(file_path: &str) -> Result<(), anyhow::Error> { + let azure = azure::Client::from_env()?; let whisper = azure.transcription_model("whisper"); let response = whisper .transcription_request() - .load_file(file_path) - .expect("Failed to load file for transcription") + .load_file(file_path)? .send() - .await - .expect("Failed to transcribe file"); + .await?; println!("Azure Whisper-1: {}", response.text); + Ok(()) } -async fn groq(file_path: &str) { - let groq = groq::Client::from_env(); +async fn groq(file_path: &str) -> Result<(), anyhow::Error> { + let groq = groq::Client::from_env()?; let whisper = groq.transcription_model(groq::WHISPER_LARGE_V3); let response = whisper .transcription_request() - .load_file(file_path) - .expect("Failed to load file for transcription") + .load_file(file_path)? 
.send() - .await - .expect("Failed to transcribe file"); + .await?; println!("Groq Whisper-Large-V3: {}", response.text); + Ok(()) } -async fn huggingface(file_path: &str) { - let huggingface = huggingface::Client::from_env(); +async fn huggingface(file_path: &str) -> Result<(), anyhow::Error> { + let huggingface = huggingface::Client::from_env()?; let whisper = huggingface.transcription_model("whisper-large-v3"); let response = whisper .transcription_request() - .load_file(file_path) - .expect("Failed to load file for transcription") + .load_file(file_path)? .send() - .await - .expect("Failed to transcribe file"); + .await?; println!("HuggingFace Whisper-Large-V3: {}", response.text); + Ok(()) } -async fn mistral(file_path: &str) { - let client = mistral::Client::from_env(); +async fn mistral(file_path: &str) -> Result<(), anyhow::Error> { + let client = mistral::Client::from_env()?; let model = client.transcription_model(mistral::VOXTRAL_MINI); let response = model .transcription_request() - .load_file(file_path) - .expect("Failed to load file for transcription") + .load_file(file_path)? .send() - .await - .expect("Failed to transcribe file using Mistral"); + .await?; println!("Mistral: {}", response.text); + Ok(()) } diff --git a/rig/rig-core/examples/vector_search.rs b/rig/rig-core/examples/vector_search.rs index 7af245a9f..a29e11519 100644 --- a/rig/rig-core/examples/vector_search.rs +++ b/rig/rig-core/examples/vector_search.rs @@ -70,7 +70,7 @@ fn print_id_matches(label: &str, matches: &[(f64, String)]) { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { - let openai_client = Client::from_env(); + let openai_client = Client::from_env()?; let embedding_model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) .documents(sample_documents())? 
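Note: the example updates in this patch all follow the same calling convention now that `ProviderClient::from_env` returns a `Result`. A minimal sketch of that convention, assuming the OpenAI provider and an `anyhow`-flavoured `main` as in `examples/agent.rs`:

    use rig::client::{CompletionClient, ProviderClient};
    use rig::completion::Prompt;
    use rig::providers::openai;

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        // `from_env()` is now fallible instead of panicking on a missing or
        // invalid environment variable, so callers propagate the error with `?`.
        let agent = openai::Client::from_env()?
            .agent(openai::GPT_4O)
            .preamble("You are a helpful assistant.")
            .build();

        let response = agent.prompt("Hello world!").await?;
        println!("{response}");
        Ok(())
    }

With this shape, a missing or malformed credential surfaces as an `Err` from `main` rather than a panic inside the client constructor.
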
diff --git a/rig/rig-core/examples/vector_search_cohere.rs b/rig/rig-core/examples/vector_search_cohere.rs index fd5da5265..e6e6d59ca 100644 --- a/rig/rig-core/examples/vector_search_cohere.rs +++ b/rig/rig-core/examples/vector_search_cohere.rs @@ -63,7 +63,7 @@ fn print_matches(matches: &[SearchMatch]) { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { - let cohere_client = Client::from_env(); + let cohere_client = Client::from_env()?; let document_model = cohere_client.embedding_model(cohere::EMBED_ENGLISH_V3, "search_document"); let search_model = cohere_client.embedding_model(cohere::EMBED_ENGLISH_V3, "search_query"); let embeddings = EmbeddingsBuilder::new(document_model.clone()) diff --git a/rig/rig-core/src/agent/prompt_request/mod.rs b/rig/rig-core/src/agent/prompt_request/mod.rs index 527af3f95..ab2e2c2a4 100644 --- a/rig/rig-core/src/agent/prompt_request/mod.rs +++ b/rig/rig-core/src/agent/prompt_request/mod.rs @@ -352,10 +352,13 @@ where // We need to do at least 2 loops for 1 roundtrip (user expects normal message) let last_prompt = loop { // Get the last message (the current prompt) - let prompt = new_messages - .last() - .expect("there should always be at least one message") - .clone(); + let Some((prompt_ref, history_for_current_turn)) = new_messages.split_last() else { + return Err(PromptError::prompt_cancelled( + build_full_history(chat_history.as_deref(), new_messages), + "prompt loop lost its pending prompt", + )); + }; + let prompt = prompt_ref.clone(); if current_max_turns > self.max_turns + 1 { break prompt; @@ -372,10 +375,8 @@ where } // Build history for hook callback (input + new messages except last) - let history_for_hook = build_history_for_request( - chat_history.as_deref(), - &new_messages[..new_messages.len().saturating_sub(1)], - ); + let history_for_hook = + build_history_for_request(chat_history.as_deref(), history_for_current_turn); if let Some(ref hook) = self.hook && let HookAction::Terminate { reason } = @@ -419,10 +420,8 @@ where }; // Build history for completion request (input + new messages except last) - let history_for_request = build_history_for_request( - chat_history.as_deref(), - &new_messages[..new_messages.len().saturating_sub(1)], - ); + let history_for_request = + build_history_for_request(chat_history.as_deref(), history_for_current_turn); let resp = build_completion_request( &self.model, @@ -630,9 +629,10 @@ where )) } } else { - unreachable!( - "This should never happen as we already filtered for `ToolCall`" - ) + Err(PromptError::prompt_cancelled( + Vec::new(), + "tool execution received non-tool assistant content", + )) } } .instrument(tool_span) @@ -643,9 +643,14 @@ where .into_iter() .collect::, _>>()?; - new_messages.push(Message::User { - content: OneOrMany::many(tool_content).expect("There is at least one tool call"), - }); + let Some(content) = OneOrMany::from_iter_optional(tool_content) else { + return Err(PromptError::prompt_cancelled( + build_full_history(chat_history.as_deref(), new_messages), + "tool execution produced no tool results", + )); + }; + + new_messages.push(Message::User { content }); }; // If we reach here, we exceeded max turns without a final response diff --git a/rig/rig-core/src/agent/prompt_request/streaming.rs b/rig/rig-core/src/agent/prompt_request/streaming.rs index f4b0a7e47..df07702b4 100644 --- a/rig/rig-core/src/agent/prompt_request/streaming.rs +++ b/rig/rig-core/src/agent/prompt_request/streaming.rs @@ -407,10 +407,15 @@ where // See also: 
https://github.com/rust-lang/rust-clippy/issues/8722 let stream = async_stream::stream! { 'outer: loop { - let current_prompt = new_messages - .last() - .cloned() - .expect("streaming loop should always have a pending prompt"); + let Some((current_prompt_ref, previous_messages)) = new_messages.split_last() else { + yield Err(cancelled_prompt_error( + chat_history.as_deref(), + new_messages.clone(), + "streaming loop lost its pending prompt".to_string(), + ).await); + break 'outer; + }; + let current_prompt = current_prompt_ref.clone(); if current_max_turns > self.max_turns + 1 { last_prompt_error = current_prompt.rag_text().unwrap_or_default(); @@ -430,7 +435,7 @@ where let history_snapshot: Vec = build_history_for_request( chat_history.as_deref(), - &new_messages[..new_messages.len().saturating_sub(1)], + previous_messages, ); if let Some(ref hook) = self.hook @@ -679,10 +684,10 @@ where content_items.extend(tool_calls.clone()); - if !content_items.is_empty() { + if let Some(content) = OneOrMany::from_iter_optional(content_items) { new_messages.push(Message::Assistant { id: stream.message_id.clone(), - content: OneOrMany::many(content_items).expect("Should have at least one item"), + content, }); } } @@ -763,14 +768,14 @@ pub async fn stream_to_stdout( Text { text }, ))) => { print!("{text}"); - std::io::Write::flush(&mut std::io::stdout()).unwrap(); + std::io::Write::flush(&mut std::io::stdout())?; } Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning( reasoning, ))) => { let reasoning = reasoning.display_text(); print!("{reasoning}"); - std::io::Write::flush(&mut std::io::stdout()).unwrap(); + std::io::Write::flush(&mut std::io::stdout())?; } Ok(MultiTurnStreamItem::FinalResponse(res)) => { final_res = res; @@ -1315,7 +1320,7 @@ mod tests { /// making the span leak deterministic (it only occurs when tasks share a thread). #[tokio::test(flavor = "current_thread")] #[ignore = "This requires an API key"] - async fn test_span_context_isolation() { + async fn test_span_context_isolation() -> anyhow::Result<()> { let stop = Arc::new(AtomicBool::new(false)); let leak_count = Arc::new(AtomicU32::new(0)); @@ -1331,7 +1336,7 @@ mod tests { // Make streaming request WITHOUT an outer span so rig creates its own invoke_agent span // (rig reuses current span if one exists, so we need to ensure there's no current span) - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env()?; let agent = client .agent(anthropic::completion::CLAUDE_HAIKU_4_5) .preamble("You are a helpful assistant.") @@ -1366,7 +1371,7 @@ mod tests { // Stop background logger stop.store(true, Ordering::Relaxed); - bg_handle.await.unwrap(); + bg_handle.await?; let leaks = leak_count.load(Ordering::Relaxed); assert_eq!( @@ -1374,6 +1379,8 @@ mod tests { "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \ This indicates that span.enter() is being used inside async_stream instead of .instrument()" ); + + Ok(()) } /// Test that FinalResponse contains the updated chat history when with_history is used. @@ -1383,10 +1390,10 @@ mod tests { /// 2. 
The history contains both the user prompt and assistant response #[tokio::test] #[ignore = "This requires an API key"] - async fn test_chat_history_in_final_response() { + async fn test_chat_history_in_final_response() -> anyhow::Result<()> { use crate::message::Message; - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env()?; let agent = client .agent(anthropic::completion::CLAUDE_HAIKU_4_5) .preamble("You are a helpful assistant. Keep responses brief.") @@ -1416,14 +1423,14 @@ mod tests { break; } Err(e) => { - panic!("Streaming error: {:?}", e); + return Err(e.into()); } _ => {} } } - let history = - final_history.expect("FinalResponse should contain history when with_history is used"); + let history = final_history + .ok_or_else(|| anyhow::anyhow!("final response should include history"))?; // Should contain at least the user message assert!( @@ -1444,5 +1451,7 @@ mod tests { history.len(), response_text ); + + Ok(()) } } diff --git a/rig/rig-core/src/agent/tool.rs b/rig/rig-core/src/agent/tool.rs index de404e497..150c72a4c 100644 --- a/rig/rig-core/src/agent/tool.rs +++ b/rig/rig-core/src/agent/tool.rs @@ -5,6 +5,7 @@ use crate::{ }; use schemars::{JsonSchema, schema_for}; use serde::{Deserialize, Serialize}; +use serde_json::json; #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct AgentToolArgs { @@ -35,8 +36,7 @@ impl Tool for Agent { ToolDefinition { name: ::name(self), description, - parameters: serde_json::to_value(schema_for!(AgentToolArgs)) - .expect("converting JSON schema to JSON value should never fail"), + parameters: json!(schema_for!(AgentToolArgs)), } } diff --git a/rig/rig-core/src/client/mod.rs b/rig/rig-core/src/client/mod.rs index 404eb750d..1baf8b757 100644 --- a/rig/rig-core/src/client/mod.rs +++ b/rig/rig-core/src/client/mod.rs @@ -14,7 +14,7 @@ pub use completion::CompletionClient; pub use embeddings::EmbeddingsClient; use http::{HeaderMap, HeaderName, HeaderValue}; pub use model_listing::{ModelLister, ModelListingClient}; -use std::{fmt::Debug, marker::PhantomData, sync::Arc}; +use std::{env::VarError, fmt::Debug, marker::PhantomData, sync::Arc}; use thiserror::Error; pub use verify::{VerifyClient, VerifyError}; @@ -53,16 +53,49 @@ pub enum ClientBuilderError { InvalidProperty(&'static str), } +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum ProviderClientError { + #[error("environment variable `{name}` is not set or is invalid")] + EnvironmentVariable { + name: &'static str, + #[source] + source: VarError, + }, + #[error(transparent)] + Http(#[from] http_client::Error), + #[error("{0}")] + InvalidConfiguration(&'static str), +} + +pub type ProviderClientResult = std::result::Result; + +pub fn required_env_var(name: &'static str) -> ProviderClientResult { + std::env::var(name).map_err(|source| ProviderClientError::EnvironmentVariable { name, source }) +} + +pub fn optional_env_var(name: &'static str) -> ProviderClientResult> { + match std::env::var(name) { + Ok(value) => Ok(Some(value)), + Err(VarError::NotPresent) => Ok(None), + Err(source) => Err(ProviderClientError::EnvironmentVariable { name, source }), + } +} + /// Abstracts over the ability to instantiate a client, either via environment variables or some /// `Self::Input` pub trait ProviderClient { type Input; + type Error; /// Create a client from the process's environment. - /// Panics if an environment is improperly configured. 
- fn from_env() -> Self; + fn from_env() -> Result + where + Self: Sized; - fn from_val(input: Self::Input) -> Self; + fn from_val(input: Self::Input) -> Result + where + Self: Sized; } /// A trait for API keys. This determines whether the key is inserted into a [Client]'s default @@ -316,7 +349,7 @@ where mut req: Request, ) -> impl Future> + WasmCompatSend where - T: Into, + T: Into + WasmCompatSend, { req.headers_mut().insert( http::header::CONTENT_TYPE, @@ -750,7 +783,7 @@ mod wasm_model_listing_compile_checks { _req: Request, ) -> impl Future> + WasmCompatSend where - T: Into, + T: Into + WasmCompatSend, { future::ready(Err(http_client::Error::StreamEnded)) } diff --git a/rig/rig-core/src/completion/request.rs b/rig/rig-core/src/completion/request.rs index 1dfd07690..0a6ef4614 100644 --- a/rig/rig-core/src/completion/request.rs +++ b/rig/rig-core/src/completion/request.rs @@ -567,9 +567,7 @@ impl CompletionRequest { }) .collect::>(); - Some(Message::User { - content: OneOrMany::many(messages).expect("There will be atleast one document"), - }) + OneOrMany::from_iter_optional(messages).map(|content| Message::User { content }) } /// Adds a provider-hosted tool by storing it in `additional_params.tools`. @@ -863,13 +861,14 @@ impl CompletionRequestBuilder { pub fn build(self) -> CompletionRequest { // Build the final message list, prepending preamble if present let mut chat_history = self.chat_history; + let prompt = self.prompt; if let Some(preamble) = self.preamble { chat_history.insert(0, Message::system(preamble)); } - chat_history.push(self.prompt); + chat_history.push(prompt.clone()); let chat_history = - OneOrMany::many(chat_history).expect("There will always be at least the prompt"); + OneOrMany::from_iter_optional(chat_history).unwrap_or_else(|| OneOrMany::one(prompt)); let additional_params = merge_provider_tools_into_additional_params( self.additional_params, self.provider_tools, diff --git a/rig/rig-core/src/embeddings/builder.rs b/rig/rig-core/src/embeddings/builder.rs index b56c29af1..fe531fe98 100644 --- a/rig/rig-core/src/embeddings/builder.rs +++ b/rig/rig-core/src/embeddings/builder.rs @@ -142,15 +142,16 @@ where .await?; // Merge the embeddings with their respective documents - Ok(docs - .into_iter() + docs.into_iter() .map(|(i, doc)| { - ( - doc, - embeddings.remove(&i).expect("Document should be present"), - ) + let embedding = embeddings.remove(&i).ok_or_else(|| { + crate::embeddings::EmbeddingError::ResponseError( + "missing embedding for document after batch merge".to_string(), + ) + })?; + Ok::<_, crate::embeddings::EmbeddingError>((doc, embedding)) }) - .collect()) + .collect::, crate::embeddings::EmbeddingError>>() } } diff --git a/rig/rig-core/src/embeddings/embedding.rs b/rig/rig-core/src/embeddings/embedding.rs index 26350acdb..95cae3af4 100644 --- a/rig/rig-core/src/embeddings/embedding.rs +++ b/rig/rig-core/src/embeddings/embedding.rs @@ -68,11 +68,12 @@ pub trait EmbeddingModel: WasmCompatSend + WasmCompatSync { text: &str, ) -> impl std::future::Future> + WasmCompatSend { async { - Ok(self - .embed_texts(vec![text.to_string()]) - .await? 
- .pop() - .expect("There should be at least one embedding")) + let mut embeddings = self.embed_texts(vec![text.to_string()]).await?; + embeddings.pop().ok_or_else(|| { + EmbeddingError::ResponseError( + "embedding provider returned an empty response for embed_text".to_string(), + ) + }) } } } @@ -97,11 +98,12 @@ pub trait ImageEmbeddingModel: Clone + WasmCompatSend + WasmCompatSync { bytes: &'a [u8], ) -> impl std::future::Future> + WasmCompatSend { async move { - Ok(self - .embed_images(vec![bytes.to_owned()]) - .await? - .pop() - .expect("There should be at least one embedding")) + let mut embeddings = self.embed_images(vec![bytes.to_owned()]).await?; + embeddings.pop().ok_or_else(|| { + EmbeddingError::ResponseError( + "embedding provider returned an empty response for embed_image".to_string(), + ) + }) } } } diff --git a/rig/rig-core/src/http_client/mod.rs b/rig/rig-core/src/http_client/mod.rs index 1bd895409..782f3b49f 100644 --- a/rig/rig-core/src/http_client/mod.rs +++ b/rig/rig-core/src/http_client/mod.rs @@ -48,6 +48,15 @@ fn instance_error(error: E) -> Error { Error::Instance(error.into()) } +async fn non_success_status_error(response: reqwest::Response) -> Error { + let status = response.status(); + let message = response + .text() + .await + .unwrap_or_else(|error| format!("failed to read error response body: {error}")); + Error::InvalidStatusCodeWithMessage(status, message) +} + pub type LazyBytes = WasmBoxedFuture<'static, Result>; pub type LazyBody = WasmBoxedFuture<'static, Result>; @@ -122,7 +131,7 @@ pub trait HttpClientExt: WasmCompatSend + WasmCompatSync { req: Request, ) -> impl Future> + WasmCompatSend where - T: Into; + T: Into + WasmCompatSend; } impl HttpClientExt for reqwest::Client { @@ -143,10 +152,7 @@ impl HttpClientExt for reqwest::Client { async move { let response = req.send().await.map_err(instance_error)?; if !response.status().is_success() { - return Err(Error::InvalidStatusCodeWithMessage( - response.status(), - response.text().await.unwrap(), - )); + return Err(non_success_status_error(response).await); } let mut res = Response::builder().status(response.status()); @@ -188,10 +194,7 @@ impl HttpClientExt for reqwest::Client { async move { let response = req.send().await.map_err(instance_error)?; if !response.status().is_success() { - return Err(Error::InvalidStatusCodeWithMessage( - response.status(), - response.text().await.unwrap(), - )); + return Err(non_success_status_error(response).await); } let mut res = Response::builder().status(response.status()); @@ -219,27 +222,22 @@ impl HttpClientExt for reqwest::Client { req: Request, ) -> impl Future> + WasmCompatSend where - T: Into, + T: Into + WasmCompatSend, { let (parts, body) = req.into_parts(); - let req = self - .request(parts.method, parts.uri.to_string()) - .headers(parts.headers) - .body(body.into()) - .build() - .map_err(|x| Error::Instance(x.into())) - .unwrap(); - let client = self.clone(); async move { + let req = self + .request(parts.method, parts.uri.to_string()) + .headers(parts.headers) + .body(body.into()) + .build() + .map_err(|error| Error::Instance(error.into()))?; let response: reqwest::Response = client.execute(req).await.map_err(instance_error)?; if !response.status().is_success() { - return Err(Error::InvalidStatusCodeWithMessage( - response.status(), - response.text().await.unwrap(), - )); + return Err(non_success_status_error(response).await); } #[cfg(not(target_family = "wasm"))] @@ -288,10 +286,7 @@ impl HttpClientExt for reqwest_middleware::ClientWithMiddleware { async 
move { let response = req.send().await.map_err(instance_error)?; if !response.status().is_success() { - return Err(Error::InvalidStatusCodeWithMessage( - response.status(), - response.text().await.unwrap(), - )); + return Err(non_success_status_error(response).await); } let mut res = Response::builder().status(response.status()); @@ -333,10 +328,7 @@ impl HttpClientExt for reqwest_middleware::ClientWithMiddleware { async move { let response = req.send().await.map_err(instance_error)?; if !response.status().is_success() { - return Err(Error::InvalidStatusCodeWithMessage( - response.status(), - response.text().await.unwrap(), - )); + return Err(non_success_status_error(response).await); } let mut res = Response::builder().status(response.status()); @@ -364,27 +356,22 @@ impl HttpClientExt for reqwest_middleware::ClientWithMiddleware { req: Request, ) -> impl Future> + WasmCompatSend where - T: Into, + T: Into + WasmCompatSend, { let (parts, body) = req.into_parts(); - let req = self - .request(parts.method, parts.uri.to_string()) - .headers(parts.headers) - .body(body.into()) - .build() - .map_err(|x| Error::Instance(x.into())) - .unwrap(); - let client = self.clone(); async move { + let req = self + .request(parts.method, parts.uri.to_string()) + .headers(parts.headers) + .body(body.into()) + .build() + .map_err(|error| Error::Instance(error.into()))?; let response: reqwest::Response = client.execute(req).await.map_err(instance_error)?; if !response.status().is_success() { - return Err(Error::InvalidStatusCodeWithMessage( - response.status(), - response.text().await.unwrap(), - )); + return Err(non_success_status_error(response).await); } #[cfg(not(target_family = "wasm"))] @@ -461,7 +448,7 @@ pub(crate) mod mock { _req: Request, ) -> impl Future> + WasmCompatSend where - T: Into, + T: Into + WasmCompatSend, { let sse_bytes = self.sse_bytes.clone(); async move { diff --git a/rig/rig-core/src/http_client/multipart.rs b/rig/rig-core/src/http_client/multipart.rs index 78731f686..af1a759d8 100644 --- a/rig/rig-core/src/http_client/multipart.rs +++ b/rig/rig-core/src/http_client/multipart.rs @@ -121,7 +121,7 @@ impl MultipartForm { use std::time::{SystemTime, UNIX_EPOCH}; let timestamp = SystemTime::now() .duration_since(UNIX_EPOCH) - .unwrap() + .unwrap_or_default() .as_nanos(); format!("----boundary{}", timestamp) } @@ -193,14 +193,17 @@ impl From for reqwest::multipart::Form { form = form.text(part.name, text); } PartContent::Binary(bytes) => { - let mut req_part = reqwest::multipart::Part::bytes(bytes.to_vec()); + let mut req_part = if let Some(content_type) = part.content_type.as_ref() { + reqwest::multipart::Part::bytes(bytes.to_vec()) + .mime_str(content_type.as_ref()) + .unwrap_or_else(|_| reqwest::multipart::Part::bytes(bytes.to_vec())) + } else { + reqwest::multipart::Part::bytes(bytes.to_vec()) + }; if let Some(filename) = part.filename { req_part = req_part.file_name(filename); } - if let Some(content_type) = part.content_type { - req_part = req_part.mime_str(content_type.as_ref()).unwrap(); - } form = form.part(part.name, req_part); } diff --git a/rig/rig-core/src/http_client/retry.rs b/rig/rig-core/src/http_client/retry.rs index ddce1eb44..d1c6c902e 100644 --- a/rig/rig-core/src/http_client/retry.rs +++ b/rig/rig-core/src/http_client/retry.rs @@ -45,7 +45,10 @@ impl ExponentialBackoff { impl RetryPolicy for ExponentialBackoff { fn retry(&self, _error: &Error, last_retry: Option<(usize, Duration)>) -> Option { if let Some((retry_num, last_duration)) = last_retry { - if 
self.max_retries.is_none() || retry_num < self.max_retries.unwrap() { + if self + .max_retries + .is_none_or(|max_retries| retry_num < max_retries) + { let duration = last_duration.mul_f64(self.factor); if let Some(max_duration) = self.max_duration { Some(duration.min(max_duration)) @@ -86,7 +89,10 @@ impl Constant { impl RetryPolicy for Constant { fn retry(&self, _error: &Error, last_retry: Option<(usize, Duration)>) -> Option { if let Some((retry_num, _)) = last_retry { - if self.max_retries.is_none() || retry_num < self.max_retries.unwrap() { + if self + .max_retries + .is_none_or(|max_retries| retry_num < max_retries) + { Some(self.delay) } else { None diff --git a/rig/rig-core/src/integrations/cli_chatbot.rs b/rig/rig-core/src/integrations/cli_chatbot.rs index b049d9f1e..6ee841ea6 100644 --- a/rig/rig-core/src/integrations/cli_chatbot.rs +++ b/rig/rig-core/src/integrations/cli_chatbot.rs @@ -182,7 +182,11 @@ where loop { print!("> "); - stdout.flush().unwrap(); + stdout.flush().map_err(|e| { + PromptError::CompletionError(CompletionError::ResponseError(format!( + "failed to flush stdout: {e}" + ))) + })?; let mut input = String::new(); match stdin.read_line(&mut input) { @@ -204,12 +208,13 @@ where println!("================================================================"); println!(); - if self.0.show_usage() { - let Usage { + if self.0.show_usage() + && let Some(Usage { input_tokens, output_tokens, .. - } = self.0.usage().unwrap(); + }) = self.0.usage() + { println!("Input {input_tokens} tokens\nOutput {output_tokens} tokens"); } } diff --git a/rig/rig-core/src/integrations/discord_bot.rs b/rig/rig-core/src/integrations/discord_bot.rs index 6c1396a29..4ee8ec594 100644 --- a/rig/rig-core/src/integrations/discord_bot.rs +++ b/rig/rig-core/src/integrations/discord_bot.rs @@ -9,9 +9,19 @@ use serenity::all::{ GatewayIntents, Interaction, Message, Ready, async_trait, }; use std::collections::HashMap; +use std::env; use std::sync::Arc; +use thiserror::Error; use tokio::sync::RwLock; +#[derive(Debug, Error)] +pub enum DiscordBotError { + #[error("Discord bot token missing from environment: {0}")] + MissingToken(#[from] env::VarError), + #[error("Failed to build Discord client: {0}")] + ClientBuild(#[from] serenity::Error), +} + // Bot state containing the agent and conversation histories struct BotState { agent: Agent, @@ -223,15 +233,15 @@ where fn into_discord_bot( self, token: &str, - ) -> impl std::future::Future + Send; + ) -> impl std::future::Future> + Send; fn into_discord_bot_from_env( self, - ) -> impl std::future::Future + Send { - let token = std::env::var("DISCORD_BOT_TOKEN") - .expect("DISCORD_BOT_TOKEN should exist as an env var"); - - async move { DiscordExt::into_discord_bot(self, &token).await } + ) -> impl std::future::Future> + Send { + async move { + let token = std::env::var("DISCORD_BOT_TOKEN")?; + DiscordExt::into_discord_bot(self, &token).await + } } } @@ -239,7 +249,7 @@ impl DiscordExt for Agent where M: CompletionModel + Send + Sync + 'static, { - async fn into_discord_bot(self, token: &str) -> serenity::Client { + async fn into_discord_bot(self, token: &str) -> Result { let intents = GatewayIntents::GUILDS | GatewayIntents::GUILD_MESSAGES | GatewayIntents::MESSAGE_CONTENT; @@ -252,6 +262,6 @@ where serenity::Client::builder(token, intents) .event_handler(handler) .await - .expect("Failed to create Discord client") + .map_err(DiscordBotError::from) } } diff --git a/rig/rig-core/src/lib.rs b/rig/rig-core/src/lib.rs index 19d6797ca..a38d9aad2 100644 --- 
a/rig/rig-core/src/lib.rs +++ b/rig/rig-core/src/lib.rs @@ -1,4 +1,14 @@ #![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr( + test, + allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable + ) +)] //! Rig is a Rust library for building LLM-powered applications that focuses on ergonomics and modularity. //! //! # Table of contents diff --git a/rig/rig-core/src/model/listing.rs b/rig/rig-core/src/model/listing.rs index 209d1fd09..b43ae66a7 100644 --- a/rig/rig-core/src/model/listing.rs +++ b/rig/rig-core/src/model/listing.rs @@ -338,7 +338,8 @@ const RESPONSE_BODY_PREVIEW_LIMIT: usize = 2048; fn format_response_body_preview(body: &[u8]) -> String { let preview_len = body.len().min(RESPONSE_BODY_PREVIEW_LIMIT); - let mut preview = String::from_utf8_lossy(&body[..preview_len]).into_owned(); + let preview_bytes = body.get(..preview_len).unwrap_or(body); + let mut preview = String::from_utf8_lossy(preview_bytes).into_owned(); if body.len() > RESPONSE_BODY_PREVIEW_LIMIT { preview.push_str(&format!( diff --git a/rig/rig-core/src/one_or_many.rs b/rig/rig-core/src/one_or_many.rs index ea9bc27bb..8dbe4d7dd 100644 --- a/rig/rig-core/src/one_or_many.rs +++ b/rig/rig-core/src/one_or_many.rs @@ -137,6 +137,19 @@ impl OneOrMany { } } + /// Build a `OneOrMany` from an iterator when the caller can naturally handle an empty input. + pub(crate) fn from_iter_optional(items: I) -> Option + where + I: IntoIterator, + { + let mut iter = items.into_iter(); + let first = iter.next()?; + Some(OneOrMany { + first, + rest: iter.collect(), + }) + } + /// Specialized try map function for OneOrMany objects. /// /// Same as `OneOrMany::map` but fallible. diff --git a/rig/rig-core/src/providers/anthropic/client.rs b/rig/rig-core/src/providers/anthropic/client.rs index e5331b2ce..b0b97fffa 100644 --- a/rig/rig-core/src/providers/anthropic/client.rs +++ b/rig/rig-core/src/providers/anthropic/client.rs @@ -105,21 +105,22 @@ impl DebugExt for AnthropicExt {} impl ProviderClient for Client { type Input = String; + type Error = crate::client::ProviderClientError; - fn from_env() -> Self + fn from_env() -> Result where Self: Sized, { - let key = std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"); + let key = crate::client::required_env_var("ANTHROPIC_API_KEY")?; - Self::builder().api_key(key).build().unwrap() + Self::builder().api_key(key).build().map_err(Into::into) } - fn from_val(input: Self::Input) -> Self + fn from_val(input: Self::Input) -> Result where Self: Sized, { - Self::builder().api_key(input).build().unwrap() + Self::builder().api_key(input).build().map_err(Into::into) } } diff --git a/rig/rig-core/src/providers/anthropic/completion.rs b/rig/rig-core/src/providers/anthropic/completion.rs index e2ed933f1..aa34dea4e 100644 --- a/rig/rig-core/src/providers/anthropic/completion.rs +++ b/rig/rig-core/src/providers/anthropic/completion.rs @@ -1030,9 +1030,13 @@ impl TryFrom for ToolChoice { )); } - Self::Tool { - name: function_names.first().unwrap().to_string(), - } + let Some(name) = function_names.into_iter().next() else { + return Err(CompletionError::ProviderError( + "Only one tool may be specified to be used by Claude".into(), + )); + }; + + Self::Tool { name } } }; diff --git a/rig/rig-core/src/providers/anthropic/decoders/line.rs b/rig/rig-core/src/providers/anthropic/decoders/line.rs index 3922eaa47..ad810f1d9 100644 --- a/rig/rig-core/src/providers/anthropic/decoders/line.rs +++ 
b/rig/rig-core/src/providers/anthropic/decoders/line.rs @@ -48,7 +48,9 @@ impl LineDecoder { if let Some(cr_index) = self.carriage_return_index { if pattern_index.index != cr_index + 1 || pattern_index.carriage { if cr_index > 0 { - let line = decode_text(&self.buffer[0..cr_index - 1]); + let line = decode_text( + self.buffer.get(..cr_index.saturating_sub(1)).unwrap_or(&[]), + ); lines.push(line); } else { // Handle edge case for carriage return at beginning @@ -56,7 +58,7 @@ impl LineDecoder { } if cr_index < self.buffer.len() { - self.buffer = self.buffer[cr_index..].to_vec(); + self.buffer = self.buffer.get(cr_index..).unwrap_or(&[]).to_vec(); } else { self.buffer.clear(); } @@ -72,14 +74,18 @@ impl LineDecoder { }; if end_index > 0 { - let line = decode_text(&self.buffer[0..end_index]); + let line = decode_text(self.buffer.get(..end_index).unwrap_or(&[])); lines.push(line); } else { lines.push(String::new()); } if pattern_index.index < self.buffer.len() { - self.buffer = self.buffer[pattern_index.index..].to_vec(); + self.buffer = self + .buffer + .get(pattern_index.index..) + .unwrap_or(&[]) + .to_vec(); } else { self.buffer.clear(); } @@ -138,24 +144,18 @@ pub fn find_double_newline_index(buffer: &[u8]) -> isize { const NEWLINE: u8 = 0x0a; // \n const CARRIAGE: u8 = 0x0d; // \r - for i in 0..buffer.len().saturating_sub(1) { - // Check for \n\n pattern - if buffer[i] == NEWLINE && buffer[i + 1] == NEWLINE { + for (i, window) in buffer.windows(2).enumerate() { + if window == [NEWLINE, NEWLINE] { return (i + 2) as isize; } - // Check for \r\r pattern - if buffer[i] == CARRIAGE && buffer[i + 1] == CARRIAGE { + if window == [CARRIAGE, CARRIAGE] { return (i + 2) as isize; } + } - // Check for \r\n\r\n pattern - if i + 3 < buffer.len() - && buffer[i] == CARRIAGE - && buffer[i + 1] == NEWLINE - && buffer[i + 2] == CARRIAGE - && buffer[i + 3] == NEWLINE - { + for (i, window) in buffer.windows(4).enumerate() { + if window == [CARRIAGE, NEWLINE, CARRIAGE, NEWLINE] { return (i + 4) as isize; } } diff --git a/rig/rig-core/src/providers/anthropic/decoders/sse.rs b/rig/rig-core/src/providers/anthropic/decoders/sse.rs index 197600304..856a99ff6 100644 --- a/rig/rig-core/src/providers/anthropic/decoders/sse.rs +++ b/rig/rig-core/src/providers/anthropic/decoders/sse.rs @@ -94,12 +94,9 @@ impl SSEDecoder { } // Parse field:value format - let parts: Vec<&str> = line.splitn(2, ':').collect(); - let (field_name, value) = match parts.as_slice() { - [field] => (*field, ""), - [field, value] => (*field, *value), - _ => unreachable!(), - }; + let (field_name, value) = line + .split_once(':') + .map_or((line.as_str(), ""), |(field, value)| (field, value)); // Trim leading space from value as per SSE spec let value = if let Some(stripped) = value.strip_prefix(' ') { @@ -183,8 +180,8 @@ fn extract_sse_chunk(buffer: &[u8]) -> Option<(Vec, Vec)> { } let pattern_index = pattern_index as usize; - let chunk = buffer[0..pattern_index].to_vec(); - let remaining = buffer[pattern_index..].to_vec(); + let chunk = buffer.get(..pattern_index)?.to_vec(); + let remaining = buffer.get(pattern_index..)?.to_vec(); Some((chunk, remaining)) } diff --git a/rig/rig-core/src/providers/anthropic/streaming.rs b/rig/rig-core/src/providers/anthropic/streaming.rs index b02f69789..e9d2e6bb0 100644 --- a/rig/rig-core/src/providers/anthropic/streaming.rs +++ b/rig/rig-core/src/providers/anthropic/streaming.rs @@ -324,7 +324,7 @@ where // Anthropic streaming API spec — use them directly. 
let usage = PartialUsage { output_tokens: usage.output_tokens, - input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")), + input_tokens: usize::try_from(input_tokens).ok(), cache_creation_input_tokens: usage.cache_creation_input_tokens, cache_read_input_tokens: usage.cache_read_input_tokens }; diff --git a/rig/rig-core/src/providers/azure.rs b/rig/rig-core/src/providers/azure.rs index 9573d6f1a..09dd0b93e 100644 --- a/rig/rig-core/src/providers/azure.rs +++ b/rig/rig-core/src/providers/azure.rs @@ -293,26 +293,29 @@ pub struct AzureOpenAIClientParams { impl ProviderClient for Client { type Input = AzureOpenAIClientParams; + type Error = crate::client::ProviderClientError; /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables. - fn from_env() -> Self { - let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") { + fn from_env() -> Result { + let auth = if let Some(api_key) = crate::client::optional_env_var("AZURE_API_KEY")? { AzureOpenAIAuth::ApiKey(api_key) - } else if let Ok(token) = std::env::var("AZURE_TOKEN") { + } else if let Some(token) = crate::client::optional_env_var("AZURE_TOKEN")? { AzureOpenAIAuth::Token(token) } else { - panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set"); + return Err(crate::client::ProviderClientError::InvalidConfiguration( + "either `AZURE_API_KEY` or `AZURE_TOKEN` must be set", + )); }; - let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set"); - let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set"); + let api_version = crate::client::required_env_var("AZURE_API_VERSION")?; + let azure_endpoint = crate::client::required_env_var("AZURE_ENDPOINT")?; Self::builder() .api_key(auth) .azure_endpoint(azure_endpoint) .api_version(&api_version) .build() - .unwrap() + .map_err(Into::into) } fn from_val( @@ -321,7 +324,7 @@ impl ProviderClient for Client { version, header, }: Self::Input, - ) -> Self { + ) -> Result { let auth = AzureOpenAIAuth::ApiKey(api_key.to_string()); Self::builder() @@ -329,7 +332,7 @@ impl ProviderClient for Client { .azure_endpoint(header) .api_version(&version) .build() - .unwrap() + .map_err(Into::into) } } @@ -455,8 +458,12 @@ where "input": documents, }); + let body_object = body.as_object_mut().ok_or_else(|| { + EmbeddingError::ResponseError("embedding request body must be a JSON object".into()) + })?; + if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 { - body["dimensions"] = json!(self.ndims); + body_object.insert("dimensions".to_owned(), json!(self.ndims)); } let body = serde_json::to_vec(&body)?; @@ -868,10 +875,14 @@ where } if let Some(ref additional_params) = request.additional_params { - for (key, value) in additional_params - .as_object() - .expect("Additional Parameters to OpenAI Transcription should be a map") - { + let params = additional_params.as_object().ok_or_else(|| { + TranscriptionError::RequestError(Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "additional transcription parameters must be a JSON object", + ))) + })?; + + for (key, value) in params { body = body.text(key.to_owned(), value.to_string()); } } @@ -1075,40 +1086,39 @@ mod azure_tests { #[tokio::test] #[ignore] - async fn test_azure_embedding() { + async fn test_azure_embedding() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let client = Client::from_env(); + let client = Client::from_env()?; 
let model = client.embedding_model(TEXT_EMBEDDING_3_SMALL); - let embeddings = model - .embed_texts(vec!["Hello, world!".to_string()]) - .await - .unwrap(); + let embeddings = model.embed_texts(vec!["Hello, world!".to_string()]).await?; tracing::info!("Azure embedding: {:?}", embeddings); + Ok(()) } #[tokio::test] #[ignore] - async fn test_azure_embedding_dimensions() { + async fn test_azure_embedding_dimensions() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); let ndims = 256; - let client = Client::from_env(); + let client = Client::from_env()?; let model = client.embedding_model_with_ndims(TEXT_EMBEDDING_3_SMALL, ndims); - let embedding = model.embed_text("Hello, world!").await.unwrap(); + let embedding = model.embed_text("Hello, world!").await?; assert!(embedding.vec.len() == ndims); tracing::info!("Azure dimensions embedding: {:?}", embedding); + Ok(()) } #[tokio::test] #[ignore] - async fn test_azure_completion() { + async fn test_azure_completion() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let client = Client::from_env(); + let client = Client::from_env()?; let model = client.completion_model(GPT_4O_MINI); let completion = model .completion(CompletionRequest { @@ -1123,15 +1133,15 @@ mod azure_tests { additional_params: None, output_schema: None, }) - .await - .unwrap(); + .await?; tracing::info!("Azure completion: {:?}", completion); + Ok(()) } #[tokio::test] #[ignore] - async fn test_azure_structured_output() { + async fn test_azure_structured_output() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); #[derive(Debug, Deserialize, JsonSchema)] @@ -1140,7 +1150,7 @@ mod azure_tests { age: u32, } - let client = Client::from_env(); + let client = Client::from_env()?; let agent = client .agent(GPT_5_MINI) .preamble("You are a helpful assistant that extracts personal details.") @@ -1150,13 +1160,13 @@ mod azure_tests { let result: Person = agent .prompt_typed("Hello! My name is John Doe and I'm 54 years old.") - .await - .expect("failed to extract person"); + .await?; assert!(result.name == "John Doe"); assert!(result.age == 54); tracing::info!("Extracted person: {:?}", result); + Ok(()) } #[tokio::test] diff --git a/rig/rig-core/src/providers/chatgpt/mod.rs b/rig/rig-core/src/providers/chatgpt/mod.rs index b21782131..ef1bc105c 100644 --- a/rig/rig-core/src/providers/chatgpt/mod.rs +++ b/rig/rig-core/src/providers/chatgpt/mod.rs @@ -209,32 +209,33 @@ impl ProviderBuilder for ChatGPTBuilder { impl ProviderClient for Client { type Input = ChatGPTAuth; + type Error = crate::client::ProviderClientError; - fn from_env() -> Self { + fn from_env() -> Result { let mut builder = Self::builder(); - if let Ok(base_url) = - std::env::var("CHATGPT_API_BASE").or_else(|_| std::env::var("OPENAI_CHATGPT_API_BASE")) + if let Some(base_url) = crate::client::optional_env_var("CHATGPT_API_BASE")? + .or(crate::client::optional_env_var("OPENAI_CHATGPT_API_BASE")?) { builder = builder.base_url(base_url); } - if let Ok(access_token) = std::env::var("CHATGPT_ACCESS_TOKEN") { - let account_id = std::env::var("CHATGPT_ACCOUNT_ID").ok(); + if let Some(access_token) = crate::client::optional_env_var("CHATGPT_ACCESS_TOKEN")? 
{ + let account_id = crate::client::optional_env_var("CHATGPT_ACCOUNT_ID")?; builder .api_key(ChatGPTAuth::AccessToken { access_token, account_id, }) .build() - .unwrap() + .map_err(Into::into) } else { - builder.oauth().build().unwrap() + builder.oauth().build().map_err(Into::into) } } - fn from_val(input: Self::Input) -> Self { - Self::builder().api_key(input).build().unwrap() + fn from_val(input: Self::Input) -> Result { + Self::builder().api_key(input).build().map_err(Into::into) } } diff --git a/rig/rig-core/src/providers/cohere/client.rs b/rig/rig-core/src/providers/cohere/client.rs index dca8030ce..234a0791d 100644 --- a/rig/rig-core/src/providers/cohere/client.rs +++ b/rig/rig-core/src/providers/cohere/client.rs @@ -67,20 +67,21 @@ impl ProviderBuilder for CohereBuilder { impl ProviderClient for Client { type Input = CohereApiKey; + type Error = crate::client::ProviderClientError; - fn from_env() -> Self + fn from_env() -> Result where Self: Sized, { - let key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set"); - Self::new(key).unwrap() + let key = crate::client::required_env_var("COHERE_API_KEY")?; + Self::new(key).map_err(Into::into) } - fn from_val(input: Self::Input) -> Self + fn from_val(input: Self::Input) -> Result where Self: Sized, { - Self::new(input).unwrap() + Self::new(input).map_err(Into::into) } } diff --git a/rig/rig-core/src/providers/cohere/completion.rs b/rig/rig-core/src/providers/cohere/completion.rs index a54270dd6..6bce94c9a 100644 --- a/rig/rig-core/src/providers/cohere/completion.rs +++ b/rig/rig-core/src/providers/cohere/completion.rs @@ -23,9 +23,11 @@ pub struct CompletionResponse { pub usage: Option, } +type AssistantMessageParts = (Vec, Vec, Vec); + impl CompletionResponse { /// Return that parts of the response for assistant messages w/o dealing with the other variants - pub fn message(&self) -> (Vec, Vec, Vec) { + pub fn message(&self) -> Result { let Message::Assistant { content, citations, @@ -33,10 +35,12 @@ impl CompletionResponse { .. } = self.message.clone() else { - unreachable!("Completion responses will only return an assistant message") + return Err(CompletionError::ResponseError( + "completion response did not contain an assistant message".into(), + )); }; - (content, citations, tool_calls) + Ok((content, citations, tool_calls)) } } @@ -137,7 +141,7 @@ impl TryFrom for completion::CompletionResponse Result { - let (content, _, tool_calls) = response.message(); + let (content, _, tool_calls) = response.message()?; let model_response = if !tool_calls.is_empty() { OneOrMany::many( @@ -151,7 +155,12 @@ impl TryFrom for completion::CompletionResponse>(), ) - .expect("We have atleast 1 tool call in this if block") + .map_err(|_| { + CompletionError::ResponseError( + "response contained tool call metadata without any callable tool content" + .to_owned(), + ) + })? 
} else { OneOrMany::many(content.into_iter().map(|content| match content { AssistantContent::Text { text } => completion::AssistantContent::text(text), @@ -185,7 +194,7 @@ impl TryFrom for completion::CompletionResponse = deserialize(&mut deserializer); let response = result.unwrap(); - let (_, citations, tool_calls) = response.message(); + let (_, citations, tool_calls) = response.message().expect("assistant message"); let CompletionResponse { id, finish_reason, diff --git a/rig/rig-core/src/providers/cohere/streaming.rs b/rig/rig-core/src/providers/cohere/streaming.rs index a8b3944bf..fbedc700f 100644 --- a/rig/rig-core/src/providers/cohere/streaming.rs +++ b/rig/rig-core/src/providers/cohere/streaming.rs @@ -134,7 +134,11 @@ where let body = serde_json::to_vec(&request)?; - let req = self.client.post("/v2/chat")?.body(body).unwrap(); + let req = self + .client + .post("/v2/chat")? + .body(body) + .map_err(|e| CompletionError::HttpError(e.into()))?; let mut event_source = GenericEventSource::new(self.client.clone(), req); diff --git a/rig/rig-core/src/providers/copilot/mod.rs b/rig/rig-core/src/providers/copilot/mod.rs index 465307736..98f3cf4dc 100644 --- a/rig/rig-core/src/providers/copilot/mod.rs +++ b/rig/rig-core/src/providers/copilot/mod.rs @@ -201,8 +201,9 @@ impl ProviderBuilder for CopilotBuilder { impl ProviderClient for Client { type Input = CopilotAuth; + type Error = crate::client::ProviderClientError; - fn from_env() -> Self { + fn from_env() -> Result { let mut builder = Self::builder(); fn get(name: &str) -> Option { std::env::var(name).ok() @@ -213,16 +214,19 @@ impl ProviderClient for Client { } if let Some(api_key) = env_api_key(&get) { - builder.api_key(api_key).build().unwrap() + builder.api_key(api_key).build().map_err(Into::into) } else if let Some(access_token) = env_github_access_token(&get) { - builder.github_access_token(access_token).build().unwrap() + builder + .github_access_token(access_token) + .build() + .map_err(Into::into) } else { - builder.oauth().build().unwrap() + builder.oauth().build().map_err(Into::into) } } - fn from_val(input: Self::Input) -> Self { - Self::builder().api_key(input).build().unwrap() + fn from_val(input: Self::Input) -> Result { + Self::builder().api_key(input).build().map_err(Into::into) } } @@ -808,8 +812,14 @@ where let auth = self.auth_context().await?; let headers = default_headers(&auth.api_key, initiator, has_vision); let mut request_json = serde_json::to_value(&request)?; - request_json["stream"] = json!(true); - request_json["stream_options"] = json!({ "include_usage": true }); + let request_object = request_json.as_object_mut().ok_or_else(|| { + CompletionError::ResponseError("copilot request body must be a JSON object".into()) + })?; + request_object.insert("stream".to_owned(), json!(true)); + request_object.insert( + "stream_options".to_owned(), + json!({ "include_usage": true }), + ); let req = apply_headers( post_with_auth_base(&self.client, &auth, "/chat/completions", Transport::Sse)?, @@ -1159,14 +1169,18 @@ where "input": documents, }); + let body_object = body.as_object_mut().ok_or_else(|| { + EmbeddingError::ResponseError("embedding request body must be a JSON object".into()) + })?; + if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 { - body["dimensions"] = json!(self.ndims); + body_object.insert("dimensions".to_owned(), json!(self.ndims)); } if let Some(encoding_format) = &self.encoding_format { - body["encoding_format"] = json!(encoding_format); + 
body_object.insert("encoding_format".to_owned(), json!(encoding_format)); } if let Some(user) = &self.user { - body["user"] = json!(user); + body_object.insert("user".to_owned(), json!(user)); } let req = apply_headers( diff --git a/rig/rig-core/src/providers/deepseek.rs b/rig/rig-core/src/providers/deepseek.rs index 7e609b2ce..15a26c05e 100644 --- a/rig/rig-core/src/providers/deepseek.rs +++ b/rig/rig-core/src/providers/deepseek.rs @@ -86,21 +86,22 @@ pub type ClientBuilder = client::ClientBuilder Self { - let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set"); + fn from_env() -> Result { + let api_key = crate::client::required_env_var("DEEPSEEK_API_KEY")?; let mut client_builder = Self::builder(); client_builder.headers_mut().insert( http::header::CONTENT_TYPE, http::HeaderValue::from_static("application/json"), ); let client_builder = client_builder.api_key(&api_key); - client_builder.build().unwrap() + client_builder.build().map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(input).map_err(Into::into) } } diff --git a/rig/rig-core/src/providers/galadriel.rs b/rig/rig-core/src/providers/galadriel.rs index 300e94203..e763a40cd 100644 --- a/rig/rig-core/src/providers/galadriel.rs +++ b/rig/rig-core/src/providers/galadriel.rs @@ -111,31 +111,31 @@ impl ClientBuilder { impl ProviderClient for Client { type Input = (String, Option); + type Error = crate::client::ProviderClientError; /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable, /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable. - /// Panics if the `GALADRIEL_API_KEY` environment variable is not set. - fn from_env() -> Self { - let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set"); - let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok(); + fn from_env() -> Result { + let api_key = crate::client::required_env_var("GALADRIEL_API_KEY")?; + let fine_tune_api_key = crate::client::optional_env_var("GALADRIEL_FINE_TUNE_API_KEY")?; let mut builder = Self::builder().api_key(api_key); - if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() { + if let Some(fine_tune_api_key) = fine_tune_api_key { builder = builder.fine_tune_api_key(fine_tune_api_key); } - builder.build().unwrap() + builder.build().map_err(Into::into) } - fn from_val((api_key, fine_tune_api_key): Self::Input) -> Self { + fn from_val((api_key, fine_tune_api_key): Self::Input) -> Result { let mut builder = Self::builder().api_key(api_key); if let Some(fine_tune_key) = fine_tune_api_key { builder = builder.fine_tune_api_key(fine_tune_key) } - builder.build().unwrap() + builder.build().map_err(Into::into) } } diff --git a/rig/rig-core/src/providers/gemini/client.rs b/rig/rig-core/src/providers/gemini/client.rs index 31b49e689..4de591e41 100644 --- a/rig/rig-core/src/providers/gemini/client.rs +++ b/rig/rig-core/src/providers/gemini/client.rs @@ -179,31 +179,31 @@ impl ProviderBuilder for GeminiInteractionsBuilder { impl ProviderClient for Client { type Input = GeminiApiKey; + type Error = crate::client::ProviderClientError; /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable. - /// Panics if the environment variable is not set. 
- fn from_env() -> Self { - let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set"); - Self::new(api_key).unwrap() + fn from_env() -> Result { + let api_key = crate::client::required_env_var("GEMINI_API_KEY")?; + Self::new(api_key).map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(input).map_err(Into::into) } } impl ProviderClient for InteractionsClient { type Input = GeminiApiKey; + type Error = crate::client::ProviderClientError; /// Create a new Google Gemini interactions client from the `GEMINI_API_KEY` environment variable. - /// Panics if the environment variable is not set. - fn from_env() -> Self { - let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set"); - Self::new(api_key).unwrap() + fn from_env() -> Result { + let api_key = crate::client::required_env_var("GEMINI_API_KEY")?; + Self::new(api_key).map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(input).map_err(Into::into) } } diff --git a/rig/rig-core/src/providers/gemini/completion.rs b/rig/rig-core/src/providers/gemini/completion.rs index 80d908848..d566aa082 100644 --- a/rig/rig-core/src/providers/gemini/completion.rs +++ b/rig/rig-core/src/providers/gemini/completion.rs @@ -2892,7 +2892,7 @@ mod tests { /// and verifies that Gemini can interpret the image in the tool result. #[tokio::test] #[ignore = "requires GEMINI_API_KEY environment variable"] - async fn test_gemini_agent_with_image_tool_result_e2e() { + async fn test_gemini_agent_with_image_tool_result_e2e() -> anyhow::Result<()> { use crate::completion::{Prompt, ToolDefinition}; use crate::prelude::*; use crate::providers::gemini; @@ -2937,7 +2937,7 @@ mod tests { } } - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env()?; let agent = client .agent("gemini-3-flash-preview") @@ -2946,21 +2946,13 @@ mod tests { .build(); // This prompt should trigger the tool, which returns an image that Gemini should process - let response = agent + let response_text = agent .prompt("Please generate a test image and tell me what color the pixel is.") - .await; - - // The test passes if Gemini successfully processes the request without errors. - // The image is a 1x1 red pixel, so Gemini should be able to describe it. 
- assert!( - response.is_ok(), - "Gemini should successfully process tool result with image: {:?}", - response.err() - ); - - let response_text = response.unwrap(); + .await?; println!("Response: {response_text}"); // Gemini should have been able to see the image and potentially describe its color assert!(!response_text.is_empty(), "Response should not be empty"); + + Ok(()) } } diff --git a/rig/rig-core/src/providers/gemini/embedding.rs b/rig/rig-core/src/providers/gemini/embedding.rs index a711bb9ac..c209d3c80 100644 --- a/rig/rig-core/src/providers/gemini/embedding.rs +++ b/rig/rig-core/src/providers/gemini/embedding.rs @@ -96,11 +96,12 @@ where let request_body = json!({ "requests": requests }); - tracing::trace!( - target: "rig::embedding", - "Sending embedding request to Gemini API {}", - serde_json::to_string_pretty(&request_body).unwrap() - ); + if let Ok(pretty_body) = serde_json::to_string_pretty(&request_body) { + tracing::trace!( + target: "rig::embedding", + "Sending embedding request to Gemini API {pretty_body}" + ); + } let request_body = serde_json::to_vec(&request_body)?; let path = format!("/v1beta/models/{}:batchEmbedContents", self.model); diff --git a/rig/rig-core/src/providers/gemini/interactions_api/mod.rs b/rig/rig-core/src/providers/gemini/interactions_api/mod.rs index cba503db2..238282a4e 100644 --- a/rig/rig-core/src/providers/gemini/interactions_api/mod.rs +++ b/rig/rig-core/src/providers/gemini/interactions_api/mod.rs @@ -914,7 +914,9 @@ pub mod interactions_api_types { .iter() .position(|exchange| exchange.call_id.as_deref() == Some(call_id)) { - exchanges[index].calls.push(call.clone()); + if let Some(exchange) = exchanges.get_mut(index) { + exchange.calls.push(call.clone()); + } index } else { exchanges.push(GoogleSearchExchange { @@ -940,7 +942,9 @@ pub mod interactions_api_types { .iter() .position(|exchange| exchange.call_id.as_deref() == Some(call_id)) { - exchanges[index].results.push(result.clone()); + if let Some(exchange) = exchanges.get_mut(index) { + exchange.results.push(result.clone()); + } } else { exchanges.push(GoogleSearchExchange { call_id: Some(call_id.clone()), @@ -949,7 +953,9 @@ pub mod interactions_api_types { }); } } else if let Some(index) = last_call_index { - exchanges[index].results.push(result.clone()); + if let Some(exchange) = exchanges.get_mut(index) { + exchange.results.push(result.clone()); + } } else { exchanges.push(GoogleSearchExchange { call_id: None, @@ -1014,7 +1020,9 @@ pub mod interactions_api_types { .iter() .position(|exchange| exchange.call_id.as_deref() == Some(call_id)) { - exchanges[index].calls.push(call.clone()); + if let Some(exchange) = exchanges.get_mut(index) { + exchange.calls.push(call.clone()); + } index } else { exchanges.push(UrlContextExchange { @@ -1040,7 +1048,9 @@ pub mod interactions_api_types { .iter() .position(|exchange| exchange.call_id.as_deref() == Some(call_id)) { - exchanges[index].results.push(result.clone()); + if let Some(exchange) = exchanges.get_mut(index) { + exchange.results.push(result.clone()); + } } else { exchanges.push(UrlContextExchange { call_id: Some(call_id.clone()), @@ -1049,7 +1059,9 @@ pub mod interactions_api_types { }); } } else if let Some(index) = last_call_index { - exchanges[index].results.push(result.clone()); + if let Some(exchange) = exchanges.get_mut(index) { + exchange.results.push(result.clone()); + } } else { exchanges.push(UrlContextExchange { call_id: None, @@ -1114,7 +1126,9 @@ pub mod interactions_api_types { .iter() .position(|exchange| 
exchange.call_id.as_deref() == Some(call_id)) { - exchanges[index].calls.push(call.clone()); + if let Some(exchange) = exchanges.get_mut(index) { + exchange.calls.push(call.clone()); + } index } else { exchanges.push(CodeExecutionExchange { @@ -1140,7 +1154,9 @@ pub mod interactions_api_types { .iter() .position(|exchange| exchange.call_id.as_deref() == Some(call_id)) { - exchanges[index].results.push(result.clone()); + if let Some(exchange) = exchanges.get_mut(index) { + exchange.results.push(result.clone()); + } } else { exchanges.push(CodeExecutionExchange { call_id: Some(call_id.clone()), @@ -1149,7 +1165,9 @@ pub mod interactions_api_types { }); } } else if let Some(index) = last_call_index { - exchanges[index].results.push(result.clone()); + if let Some(exchange) = exchanges.get_mut(index) { + exchange.results.push(result.clone()); + } } else { exchanges.push(CodeExecutionExchange { call_id: None, diff --git a/rig/rig-core/src/providers/gemini/interactions_api/streaming.rs b/rig/rig-core/src/providers/gemini/interactions_api/streaming.rs index 6b4fe4c32..41500db5b 100644 --- a/rig/rig-core/src/providers/gemini/interactions_api/streaming.rs +++ b/rig/rig-core/src/providers/gemini/interactions_api/streaming.rs @@ -190,7 +190,9 @@ where let data = serde_json::from_str::(&message.data); let Ok(data) = data else { - let err = data.unwrap_err(); + let Err(err) = data else { + continue; + }; tracing::debug!("Failed to deserialize interactions SSE event: {err}"); continue; }; diff --git a/rig/rig-core/src/providers/groq.rs b/rig/rig-core/src/providers/groq.rs index ae7388ecc..fb1cdf353 100644 --- a/rig/rig-core/src/providers/groq.rs +++ b/rig/rig-core/src/providers/groq.rs @@ -91,16 +91,16 @@ pub type ClientBuilder = client::ClientBuilder Self { - let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set"); - Self::new(&api_key).unwrap() + fn from_env() -> Result { + let api_key = crate::client::required_env_var("GROQ_API_KEY")?; + Self::new(&api_key).map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(&input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(&input).map_err(Into::into) } } @@ -562,10 +562,14 @@ where } if let Some(ref additional_params) = request.additional_params { - for (key, value) in additional_params - .as_object() - .expect("Additional Parameters to OpenAI Transcription should be a map") - { + let params = additional_params.as_object().ok_or_else(|| { + TranscriptionError::RequestError(Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "additional transcription parameters must be a JSON object", + ))) + })?; + + for (key, value) in params { body = body.text(key.to_owned(), value.to_string()); } } @@ -574,9 +578,9 @@ where .client .post("/audio/transcriptions")? 
.body(body) - .unwrap(); + .map_err(|e| TranscriptionError::HttpError(e.into()))?; - let response = self.client.send_multipart::(req).await.unwrap(); + let response = self.client.send_multipart::(req).await?; let status = response.status(); let response_body = response.into_body().into_future().await?.to_vec(); diff --git a/rig/rig-core/src/providers/huggingface/client.rs b/rig/rig-core/src/providers/huggingface/client.rs index dc127b89a..685256756 100644 --- a/rig/rig-core/src/providers/huggingface/client.rs +++ b/rig/rig-core/src/providers/huggingface/client.rs @@ -160,17 +160,17 @@ impl ProviderBuilder for HuggingFaceBuilder { impl ProviderClient for Client { type Input = String; + type Error = crate::client::ProviderClientError; /// Create a new Huggingface client from the `HUGGINGFACE_API_KEY` environment variable. - /// Panics if the environment variable is not set. - fn from_env() -> Self { - let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set"); + fn from_env() -> Result { + let api_key = crate::client::required_env_var("HUGGINGFACE_API_KEY")?; - Self::new(&api_key).unwrap() + Self::new(&api_key).map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(&input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(&input).map_err(Into::into) } } diff --git a/rig/rig-core/src/providers/huggingface/completion.rs b/rig/rig-core/src/providers/huggingface/completion.rs index fe521313d..48de58493 100644 --- a/rig/rig-core/src/providers/huggingface/completion.rs +++ b/rig/rig-core/src/providers/huggingface/completion.rs @@ -307,13 +307,18 @@ impl TryFrom for Vec { )), })?, }), - _ => unreachable!(), + _ => Err(message::MessageError::ConversionError( + "expected tool result content while converting HuggingFace input" + .into(), + )), }) .collect::, _>>() } else { - let other_content = OneOrMany::many(other_content).expect( - "There must be other content here if there were no tool result content", - ); + let other_content = OneOrMany::many(other_content).map_err(|_| { + message::MessageError::ConversionError( + "HuggingFace user message did not contain any non-tool content".into(), + ) + })?; Ok(vec![Message::User { content: other_content.try_map(|content| match content { @@ -360,7 +365,10 @@ impl TryFrom for Vec { // Silently skip unsupported reasoning content. 
} message::AssistantContent::Image(_) => { - panic!("Image content is not supported on HuggingFace via Rig"); + return Err(message::MessageError::ConversionError( + "HuggingFace assistant messages do not support image content" + .into(), + )); } } } diff --git a/rig/rig-core/src/providers/hyperbolic.rs b/rig/rig-core/src/providers/hyperbolic.rs index 38345249c..586513528 100644 --- a/rig/rig-core/src/providers/hyperbolic.rs +++ b/rig/rig-core/src/providers/hyperbolic.rs @@ -79,16 +79,16 @@ pub type ClientBuilder = client::ClientBuilder Self { - let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set"); - Self::new(&api_key).unwrap() + fn from_env() -> Result { + let api_key = crate::client::required_env_var("HYPERBOLIC_API_KEY")?; + Self::new(&api_key).map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(input).map_err(Into::into) } } @@ -537,9 +537,13 @@ mod image_generation { type Error = ImageGenerationError; fn try_from(value: ImageGenerationResponse) -> Result { + let image = value + .images + .first() + .ok_or_else(|| ImageGenerationError::ResponseError("missing image data".into()))?; let data = BASE64_STANDARD - .decode(&value.images[0].image) - .expect("Could not decode image."); + .decode(&image.image) + .map_err(|err| ImageGenerationError::ResponseError(err.to_string()))?; Ok(Self { image: data, @@ -644,7 +648,7 @@ mod audio_generation { fn try_from(value: AudioGenerationResponse) -> Result { let data = BASE64_STANDARD .decode(&value.audio) - .expect("Could not decode audio."); + .map_err(|err| AudioGenerationError::ResponseError(err.to_string()))?; Ok(Self { audio: data, diff --git a/rig/rig-core/src/providers/internal/openai_chat_completions_compatible.rs b/rig/rig-core/src/providers/internal/openai_chat_completions_compatible.rs index 16d394b94..7f526683e 100644 --- a/rig/rig-core/src/providers/internal/openai_chat_completions_compatible.rs +++ b/rig/rig-core/src/providers/internal/openai_chat_completions_compatible.rs @@ -251,12 +251,11 @@ where if let Some(existing) = tool_calls.get(&incoming.index) && profile.should_evict(existing, &incoming) { - let evicted = tool_calls - .remove(&incoming.index) - .expect("checked above"); - yield Ok(RawStreamingChoice::ToolCall( - finalize_completed_streaming_tool_call(evicted), - )); + if let Some(evicted) = tool_calls.remove(&incoming.index) { + yield Ok(RawStreamingChoice::ToolCall( + finalize_completed_streaming_tool_call(evicted), + )); + } } let existing_tool_call = tool_calls diff --git a/rig/rig-core/src/providers/llamafile.rs b/rig/rig-core/src/providers/llamafile.rs index b0b22ad80..abd8491dd 100644 --- a/rig/rig-core/src/providers/llamafile.rs +++ b/rig/rig-core/src/providers/llamafile.rs @@ -100,26 +100,26 @@ pub type ClientBuilder = client::ClientBuilder Self { + pub fn from_url(base_url: &str) -> crate::client::ProviderClientResult { Self::builder() .api_key(Nothing) .base_url(base_url) .build() - .expect("Failed to build llamafile client") + .map_err(Into::into) } } impl ProviderClient for Client { type Input = Nothing; + type Error = crate::client::ProviderClientError; - fn from_env() -> Self { - let api_base = - std::env::var("LLAMAFILE_API_BASE_URL").expect("LLAMAFILE_API_BASE_URL not set"); + fn from_env() -> Result { + let api_base = crate::client::required_env_var("LLAMAFILE_API_BASE_URL")?; Self::from_url(&api_base) } - fn from_val(_: Self::Input) -> Self { - 
Self::builder().api_key(Nothing).build().unwrap() + fn from_val(_: Self::Input) -> Result { + Self::builder().api_key(Nothing).build().map_err(Into::into) } } diff --git a/rig/rig-core/src/providers/minimax.rs b/rig/rig-core/src/providers/minimax.rs index 2d0482c23..9cbd0776b 100644 --- a/rig/rig-core/src/providers/minimax.rs +++ b/rig/rig-core/src/providers/minimax.rs @@ -171,49 +171,57 @@ impl super::anthropic::completion::AnthropicCompatibleProvider for MiniMaxAnthro impl ProviderClient for Client { type Input = MiniMaxApiKey; + type Error = crate::client::ProviderClientError; - fn from_env() -> Self { - let api_key = std::env::var("MINIMAX_API_KEY").expect("MINIMAX_API_KEY not set"); + fn from_env() -> Result { + let api_key = crate::client::required_env_var("MINIMAX_API_KEY")?; let mut builder = Self::builder().api_key(api_key); - if let Ok(base_url) = std::env::var("MINIMAX_API_BASE") { + if let Some(base_url) = crate::client::optional_env_var("MINIMAX_API_BASE")? { builder = builder.base_url(base_url); } - builder.build().unwrap() + builder.build().map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(input).map_err(Into::into) } } impl ProviderClient for AnthropicClient { type Input = String; + type Error = crate::client::ProviderClientError; - fn from_env() -> Self { - let api_key = std::env::var("MINIMAX_API_KEY").expect("MINIMAX_API_KEY not set"); + fn from_env() -> Result { + let api_key = crate::client::required_env_var("MINIMAX_API_KEY")?; let mut builder = Self::builder().api_key(api_key); if let Some(base_url) = - anthropic_base_override("MINIMAX_ANTHROPIC_API_BASE", "MINIMAX_API_BASE") + anthropic_base_override("MINIMAX_ANTHROPIC_API_BASE", "MINIMAX_API_BASE")? { builder = builder.base_url(base_url); } - builder.build().unwrap() + builder.build().map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::builder().api_key(input).build().unwrap() + fn from_val(input: Self::Input) -> Result { + Self::builder().api_key(input).build().map_err(Into::into) } } -fn anthropic_base_override(primary_env: &str, fallback_env: &str) -> Option { - let primary = std::env::var(primary_env).ok(); - let fallback = std::env::var(fallback_env).ok(); - - resolve_anthropic_base_override(primary.as_deref(), fallback.as_deref()) +fn anthropic_base_override( + primary_env: &'static str, + fallback_env: &'static str, +) -> crate::client::ProviderClientResult> { + let primary = crate::client::optional_env_var(primary_env)?; + let fallback = crate::client::optional_env_var(fallback_env)?; + + Ok(resolve_anthropic_base_override( + primary.as_deref(), + fallback.as_deref(), + )) } fn resolve_anthropic_base_override( diff --git a/rig/rig-core/src/providers/mira.rs b/rig/rig-core/src/providers/mira.rs index 394237fe3..f1bef34e8 100644 --- a/rig/rig-core/src/providers/mira.rs +++ b/rig/rig-core/src/providers/mira.rs @@ -197,16 +197,16 @@ where impl ProviderClient for Client { type Input = String; + type Error = crate::client::ProviderClientError; /// Create a new Mira client from the `MIRA_API_KEY` environment variable. - /// Panics if the environment variable is not set. 
- fn from_env() -> Self { - let api_key = std::env::var("MIRA_API_KEY").expect("MIRA_API_KEY not set"); - Self::new(&api_key).unwrap() + fn from_env() -> Result { + let api_key = crate::client::required_env_var("MIRA_API_KEY")?; + Self::new(&api_key).map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(&input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(&input).map_err(Into::into) } } @@ -651,9 +651,12 @@ impl TryFrom for Message { type Error = CompletionError; fn try_from(value: serde_json::Value) -> Result { - let role = value["role"].as_str().ok_or_else(|| { - CompletionError::ResponseError("Message missing role field".to_owned()) - })?; + let role = value + .get("role") + .and_then(serde_json::Value::as_str) + .ok_or_else(|| { + CompletionError::ResponseError("Message missing role field".to_owned()) + })?; // Handle both string and array content formats let content = match value.get("content") { diff --git a/rig/rig-core/src/providers/mistral/client.rs b/rig/rig-core/src/providers/mistral/client.rs index 2faa3f4eb..aa37d4beb 100644 --- a/rig/rig-core/src/providers/mistral/client.rs +++ b/rig/rig-core/src/providers/mistral/client.rs @@ -64,19 +64,19 @@ impl ProviderBuilder for MistralBuilder { impl ProviderClient for Client { type Input = String; + type Error = crate::client::ProviderClientError; /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable. - /// Panics if the environment variable is not set. - fn from_env() -> Self + fn from_env() -> Result where Self: Sized, { - let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set"); - Self::new(&api_key).unwrap() + let api_key = crate::client::required_env_var("MISTRAL_API_KEY")?; + Self::new(&api_key).map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(&input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(&input).map_err(Into::into) } } diff --git a/rig/rig-core/src/providers/mistral/completion.rs b/rig/rig-core/src/providers/mistral/completion.rs index a0224115e..46053b77e 100644 --- a/rig/rig-core/src/providers/mistral/completion.rs +++ b/rig/rig-core/src/providers/mistral/completion.rs @@ -165,7 +165,9 @@ impl TryFrom for Vec { // reasoning items. Silently skip to avoid crashing the process. } message::AssistantContent::Image(_) => { - panic!("Image content is not currently supported on Mistral via Rig"); + return Err(message::MessageError::ConversionError( + "Mistral assistant messages do not support image content".into(), + )); } } } diff --git a/rig/rig-core/src/providers/moonshot.rs b/rig/rig-core/src/providers/moonshot.rs index 647129e3a..b1afc0e47 100644 --- a/rig/rig-core/src/providers/moonshot.rs +++ b/rig/rig-core/src/providers/moonshot.rs @@ -157,39 +157,40 @@ pub type AnthropicClientBuilder = impl ProviderClient for Client { type Input = String; + type Error = crate::client::ProviderClientError; /// Create a new Moonshot client from the `MOONSHOT_API_KEY` environment variable. - /// Panics if the environment variable is not set. - fn from_env() -> Self { - let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set"); + fn from_env() -> Result { + let api_key = crate::client::required_env_var("MOONSHOT_API_KEY")?; let mut builder = Self::builder().api_key(&api_key); - if let Ok(base_url) = std::env::var("MOONSHOT_API_BASE") { + if let Some(base_url) = crate::client::optional_env_var("MOONSHOT_API_BASE")? 
{ builder = builder.base_url(base_url); } - builder.build().unwrap() + builder.build().map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(&input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(&input).map_err(Into::into) } } impl ProviderClient for AnthropicClient { type Input = String; + type Error = crate::client::ProviderClientError; - fn from_env() -> Self { - let api_key = std::env::var("MOONSHOT_API_KEY").expect("MOONSHOT_API_KEY not set"); + fn from_env() -> Result { + let api_key = crate::client::required_env_var("MOONSHOT_API_KEY")?; let mut builder = Self::builder().api_key(api_key); if let Some(base_url) = - anthropic_base_override("MOONSHOT_ANTHROPIC_API_BASE", "MOONSHOT_API_BASE") + anthropic_base_override("MOONSHOT_ANTHROPIC_API_BASE", "MOONSHOT_API_BASE")? { builder = builder.base_url(base_url); } - builder.build().unwrap() + builder.build().map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::builder().api_key(input).build().unwrap() + fn from_val(input: Self::Input) -> Result { + Self::builder().api_key(input).build().map_err(Into::into) } } @@ -240,11 +241,17 @@ impl super::anthropic::completion::AnthropicCompatibleProvider for MoonshotAnthr } } -fn anthropic_base_override(primary_env: &str, fallback_env: &str) -> Option { - let primary = std::env::var(primary_env).ok(); - let fallback = std::env::var(fallback_env).ok(); - - resolve_anthropic_base_override(primary.as_deref(), fallback.as_deref()) +fn anthropic_base_override( + primary_env: &'static str, + fallback_env: &'static str, +) -> crate::client::ProviderClientResult> { + let primary = crate::client::optional_env_var(primary_env)?; + let fallback = crate::client::optional_env_var(fallback_env)?; + + Ok(resolve_anthropic_base_override( + primary.as_deref(), + fallback.as_deref(), + )) } fn resolve_anthropic_base_override( diff --git a/rig/rig-core/src/providers/ollama.rs b/rig/rig-core/src/providers/ollama.rs index 0ed26281b..e6ae33783 100644 --- a/rig/rig-core/src/providers/ollama.rs +++ b/rig/rig-core/src/providers/ollama.rs @@ -155,12 +155,13 @@ pub type ClientBuilder = client::ClientBuilder Self { - let api_base = std::env::var("OLLAMA_API_BASE_URL") - .unwrap_or_else(|_| OLLAMA_API_BASE_URL.to_string()); + fn from_env() -> Result { + let api_base = crate::client::optional_env_var("OLLAMA_API_BASE_URL")? + .unwrap_or_else(|| OLLAMA_API_BASE_URL.to_string()); - let api_key: OllamaApiKey = std::env::var("OLLAMA_API_KEY") + let api_key = crate::client::optional_env_var("OLLAMA_API_KEY")? 
.map(OllamaApiKey::from) .unwrap_or_default(); @@ -168,14 +169,11 @@ impl ProviderClient for Client { .api_key(api_key) .base_url(&api_base) .build() - .expect("failed to build Ollama client from environment") + .map_err(Into::into) } - fn from_val(api_key: Self::Input) -> Self { - Self::builder() - .api_key(api_key) - .build() - .expect("failed to build Ollama client") + fn from_val(api_key: Self::Input) -> Result { + Self::builder().api_key(api_key).build().map_err(Into::into) } } @@ -759,7 +757,9 @@ where name: None, tool_calls: tool_calls_final.clone() }; - span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap()); + if let Ok(serialized_message) = serde_json::to_string(&vec![message]) { + span.record("gen_ai.output.messages", serialized_message); + } yield RawStreamingChoice::FinalResponse( StreamingCompletionResponse { total_duration: response.total_duration, @@ -968,7 +968,9 @@ impl TryFrom for Vec { content: content_string, }) } - _ => unreachable!(), + _ => Err(crate::message::MessageError::ConversionError( + "expected tool result content while converting Ollama input".into(), + )), }) .collect::, _>>() } else { @@ -1086,10 +1088,14 @@ impl From for crate::completion::Message { ), ); } - crate::completion::Message::Assistant { - id: None, - content: OneOrMany::many(assistant_contents).unwrap(), - } + let content = + OneOrMany::from_iter_optional(assistant_contents).unwrap_or_else(|| { + OneOrMany::one(crate::completion::message::AssistantContent::Text(Text { + text: String::new(), + })) + }); + + crate::completion::Message::Assistant { id: None, content } } // System and ToolResult are converted to User message as needed. Message::System { content, .. } => crate::completion::Message::User { diff --git a/rig/rig-core/src/providers/openai/client.rs b/rig/rig-core/src/providers/openai/client.rs index a6f282737..f4cf5416c 100644 --- a/rig/rig-core/src/providers/openai/client.rs +++ b/rig/rig-core/src/providers/openai/client.rs @@ -207,12 +207,12 @@ where impl ProviderClient for Client { type Input = OpenAIApiKey; + type Error = crate::client::ProviderClientError; /// Create a new OpenAI Responses API client from the `OPENAI_API_KEY` environment variable. - /// Panics if the environment variable is not set. - fn from_env() -> Self { - let base_url: Option = std::env::var("OPENAI_BASE_URL").ok(); - let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + fn from_env() -> Result { + let base_url = crate::client::optional_env_var("OPENAI_BASE_URL")?; + let api_key = crate::client::required_env_var("OPENAI_API_KEY")?; let mut builder = Client::builder().api_key(&api_key); @@ -220,22 +220,22 @@ impl ProviderClient for Client { builder = builder.base_url(&base); } - builder.build().unwrap() + builder.build().map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(input).map_err(Into::into) } } impl ProviderClient for CompletionsClient { type Input = OpenAIApiKey; + type Error = crate::client::ProviderClientError; /// Create a new OpenAI Completions API client from the `OPENAI_API_KEY` environment variable. - /// Panics if the environment variable is not set. 
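// The Ollama hunk above swaps `OneOrMany::many(...).unwrap()` for the
// `OneOrMany::from_iter_optional` helper plus an explicit empty-text fallback.
// The same pattern in isolation, assuming (as the hunk implies) that
// `from_iter_optional` returns `None` for an empty iterator:
use rig::OneOrMany;
use rig::completion::message::{AssistantContent, Text};

fn into_choice(items: Vec<AssistantContent>) -> OneOrMany<AssistantContent> {
    // `OneOrMany::many` fails on an empty Vec (and the old code unwrapped it);
    // falling back to a single empty text block keeps the conversion total.
    OneOrMany::from_iter_optional(items).unwrap_or_else(|| {
        OneOrMany::one(AssistantContent::Text(Text {
            text: String::new(),
        }))
    })
}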
- fn from_env() -> Self { - let base_url: Option = std::env::var("OPENAI_BASE_URL").ok(); - let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + fn from_env() -> Result { + let base_url = crate::client::optional_env_var("OPENAI_BASE_URL")?; + let api_key = crate::client::required_env_var("OPENAI_API_KEY")?; let mut builder = CompletionsClient::builder().api_key(&api_key); @@ -243,11 +243,11 @@ impl ProviderClient for CompletionsClient { builder = builder.base_url(&base); } - builder.build().unwrap() + builder.build().map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(input).map_err(Into::into) } } diff --git a/rig/rig-core/src/providers/openai/completion/mod.rs b/rig/rig-core/src/providers/openai/completion/mod.rs index 32a108ee9..f8fb0b600 100644 --- a/rig/rig-core/src/providers/openai/completion/mod.rs +++ b/rig/rig-core/src/providers/openai/completion/mod.rs @@ -550,7 +550,9 @@ impl TryFrom> for Vec { .into_iter() .map(|content| match content { message::UserContent::ToolResult(tool_result) => tool_result.try_into(), - _ => unreachable!(), + _ => Err(message::MessageError::ConversionError( + "expected tool result content while converting OpenAI input".into(), + )), }) .collect::, _>>() } else { @@ -559,8 +561,11 @@ impl TryFrom> for Vec { .map(|content| content.try_into()) .collect::, _>>()?; - let other_content = OneOrMany::many(other_content) - .expect("There must be other content here if there were no tool result content"); + let other_content = OneOrMany::many(other_content).map_err(|_| { + message::MessageError::ConversionError( + "OpenAI user message did not contain any non-tool content".into(), + ) + })?; Ok(vec![Message::User { content: other_content, @@ -586,9 +591,10 @@ impl TryFrom> for Vec { reasoning_text.push_str(&reasoning.display_text()); } message::AssistantContent::Image(_) => { - panic!( - "The OpenAI Completions API doesn't support image content in assistant messages!" 
- ); + return Err(message::MessageError::ConversionError( + "OpenAI assistant messages do not support image content in chat completions" + .into(), + )); } } } diff --git a/rig/rig-core/src/providers/openai/completion/streaming.rs b/rig/rig-core/src/providers/openai/completion/streaming.rs index ff43ba6c3..29d7814c0 100644 --- a/rig/rig-core/src/providers/openai/completion/streaming.rs +++ b/rig/rig-core/src/providers/openai/completion/streaming.rs @@ -102,9 +102,8 @@ where strict_tools: self.strict_tools, tool_result_array_content: self.tool_result_array_content, })?; - let request_messages = serde_json::to_string(&request.messages) - .expect("Converting to JSON from a Rust struct shouldn't fail"); - let mut request_as_json = serde_json::to_value(request).expect("this should never fail"); + let request_messages = serde_json::to_string(&request.messages)?; + let mut request_as_json = serde_json::to_value(request)?; request_as_json = merge( request_as_json, diff --git a/rig/rig-core/src/providers/openai/embedding.rs b/rig/rig-core/src/providers/openai/embedding.rs index 2c3e90719..3060106db 100644 --- a/rig/rig-core/src/providers/openai/embedding.rs +++ b/rig/rig-core/src/providers/openai/embedding.rs @@ -113,16 +113,20 @@ where "input": documents, }); + let body_object = body.as_object_mut().ok_or_else(|| { + EmbeddingError::ResponseError("embedding request body must be a JSON object".into()) + })?; + if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 { - body["dimensions"] = json!(self.ndims); + body_object.insert("dimensions".to_owned(), json!(self.ndims)); } if let Some(encoding_format) = &self.encoding_format { - body["encoding_format"] = json!(encoding_format); + body_object.insert("encoding_format".to_owned(), json!(encoding_format)); } if let Some(user) = &self.user { - body["user"] = json!(user); + body_object.insert("user".to_owned(), json!(user)); } let body = serde_json::to_vec(&body)?; diff --git a/rig/rig-core/src/providers/openai/image_generation.rs b/rig/rig-core/src/providers/openai/image_generation.rs index 17c5a5197..ff2a52c36 100644 --- a/rig/rig-core/src/providers/openai/image_generation.rs +++ b/rig/rig-core/src/providers/openai/image_generation.rs @@ -33,11 +33,16 @@ impl TryFrom type Error = ImageGenerationError; fn try_from(value: ImageGenerationResponse) -> Result { - let b64_json = value.data[0].b64_json.clone(); + let b64_json = value + .data + .first() + .ok_or_else(|| ImageGenerationError::ResponseError("missing image data".into()))? 
+ .b64_json + .clone(); let bytes = BASE64_STANDARD .decode(&b64_json) - .expect("Failed to decode b64"); + .map_err(|err| ImageGenerationError::ResponseError(err.to_string()))?; Ok(image_generation::ImageGenerationResponse { image: bytes, diff --git a/rig/rig-core/src/providers/openai/responses_api/mod.rs b/rig/rig-core/src/providers/openai/responses_api/mod.rs index 98aaa5e17..c3388ef8b 100644 --- a/rig/rig-core/src/providers/openai/responses_api/mod.rs +++ b/rig/rig-core/src/providers/openai/responses_api/mod.rs @@ -1666,7 +1666,10 @@ impl TryFrom for Vec { } }, }), - _ => unreachable!(), + _ => Err(MessageError::ConversionError( + "expected tool result content while converting Responses API input" + .into(), + )), }) .collect::, _>>() } else { @@ -1761,7 +1764,9 @@ impl TryFrom for Vec { message::UserContent::Audio(_) => Err(MessageError::ConversionError( "Audio must be base64 encoded data".into(), )), - _ => unreachable!(), + _ => Err(MessageError::ConversionError( + "Unsupported user content for OpenAI Responses API".into(), + )), }) .collect::, _>>()?; diff --git a/rig/rig-core/src/providers/openai/responses_api/streaming.rs b/rig/rig-core/src/providers/openai/responses_api/streaming.rs index 8b95a8e68..707b5434b 100644 --- a/rig/rig-core/src/providers/openai/responses_api/streaming.rs +++ b/rig/rig-core/src/providers/openai/responses_api/streaming.rs @@ -346,7 +346,12 @@ impl RawChoiceAccumulator { if options.errors_on_terminal_response() => { let error_message = response_chunk_error_message(&kind, &response, provider_name) - .expect("terminal response should have an error message"); + .unwrap_or_else(|| { + format!( + "{provider_name} returned terminal response {:?} without an error message", + kind + ) + }); Err(CompletionError::ProviderError(error_message)) } _ => Ok(()), @@ -507,11 +512,13 @@ pub(crate) fn raw_choices_from_sse_body( Some("response.completed") | Some("response.failed") | Some("response.incomplete") => { if let Some(response) = value.get("response").cloned() { let response = serde_json::from_value::(response)?; - let kind = match value.get("type").and_then(serde_json::Value::as_str) { - Some("response.completed") => ResponseChunkKind::ResponseCompleted, - Some("response.failed") => ResponseChunkKind::ResponseFailed, - Some("response.incomplete") => ResponseChunkKind::ResponseIncomplete, - _ => unreachable!("filtered above"), + let Some(kind) = (match value.get("type").and_then(serde_json::Value::as_str) { + Some("response.completed") => Some(ResponseChunkKind::ResponseCompleted), + Some("response.failed") => Some(ResponseChunkKind::ResponseFailed), + Some("response.incomplete") => Some(ResponseChunkKind::ResponseIncomplete), + _ => None, + }) else { + continue; }; accumulator.record_response_chunk(kind, response, provider_name, options)?; } @@ -649,7 +656,13 @@ where let data = serde_json::from_str::(&evt.data); let Ok(data) = data else { - debug!("Couldn't deserialize SSE data as StreamingCompletionChunk: {:?}", data.unwrap_err()); + let Err(err) = data else { + continue; + }; + debug!( + "Couldn't deserialize SSE data as StreamingCompletionChunk: {:?}", + err + ); continue; }; diff --git a/rig/rig-core/src/providers/openai/transcription.rs b/rig/rig-core/src/providers/openai/transcription.rs index eaa5d77a3..bc57588ff 100644 --- a/rig/rig-core/src/providers/openai/transcription.rs +++ b/rig/rig-core/src/providers/openai/transcription.rs @@ -84,10 +84,14 @@ where } if let Some(ref additional_params) = request.additional_params { - for (key, value) in 
additional_params - .as_object() - .expect("Additional Parameters to OpenAI Transcription should be a map") - { + let params = additional_params.as_object().ok_or_else(|| { + TranscriptionError::RequestError(Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "additional transcription parameters must be a JSON object", + ))) + })?; + + for (key, value) in params { body = body.text(key.to_owned(), value.to_string()); } } @@ -96,9 +100,9 @@ where .client .post("/audio/transcriptions")? .body(body) - .unwrap(); + .map_err(|e| TranscriptionError::HttpError(e.into()))?; - let response = self.client.send_multipart::(req).await.unwrap(); + let response = self.client.send_multipart::(req).await?; let status = response.status(); let response_body = response.into_body().into_future().await?.to_vec(); diff --git a/rig/rig-core/src/providers/openrouter/client.rs b/rig/rig-core/src/providers/openrouter/client.rs index e83ae9170..fb802372e 100644 --- a/rig/rig-core/src/providers/openrouter/client.rs +++ b/rig/rig-core/src/providers/openrouter/client.rs @@ -66,17 +66,17 @@ impl ProviderBuilder for OpenRouterExtBuilder { impl ProviderClient for Client { type Input = OpenRouterApiKey; + type Error = crate::client::ProviderClientError; /// Create a new openrouter client from the `OPENROUTER_API_KEY` environment variable. - /// Panics if the environment variable is not set. - fn from_env() -> Self { - let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set"); + fn from_env() -> Result { + let api_key = crate::client::required_env_var("OPENROUTER_API_KEY")?; - Self::new(&api_key).unwrap() + Self::new(&api_key).map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(input).map_err(Into::into) } } diff --git a/rig/rig-core/src/providers/openrouter/completion.rs b/rig/rig-core/src/providers/openrouter/completion.rs index da83e42ff..edd85de05 100644 --- a/rig/rig-core/src/providers/openrouter/completion.rs +++ b/rig/rig-core/src/providers/openrouter/completion.rs @@ -1194,7 +1194,9 @@ impl TryFrom> for Vec { .collect::>() .join("\n"), }), - _ => unreachable!(), + _ => Err(message::MessageError::ConversionError( + "expected tool result content while converting OpenRouter input".into(), + )), }) .collect::, _>>() } else { @@ -1203,8 +1205,11 @@ impl TryFrom> for Vec { .map(|content| content.try_into()) .collect::, _>>()?; - let content = OneOrMany::many(user_content) - .expect("There must be content here if there were no tool result content"); + let content = OneOrMany::many(user_content).map_err(|_| { + message::MessageError::ConversionError( + "OpenRouter user message did not contain any non-tool content".into(), + ) + })?; Ok(vec![Message::User { content, diff --git a/rig/rig-core/src/providers/openrouter/embedding.rs b/rig/rig-core/src/providers/openrouter/embedding.rs index 24459ad41..399842331 100644 --- a/rig/rig-core/src/providers/openrouter/embedding.rs +++ b/rig/rig-core/src/providers/openrouter/embedding.rs @@ -86,16 +86,20 @@ where "input": documents, }); + let body_object = body.as_object_mut().ok_or_else(|| { + EmbeddingError::ResponseError("embedding request body must be a JSON object".into()) + })?; + if self.ndims > 0 { - body["dimensions"] = json!(self.ndims); + body_object.insert("dimensions".to_owned(), json!(self.ndims)); } if let Some(encoding_format) = &self.encoding_format { - body["encoding_format"] = json!(encoding_format); + 
body_object.insert("encoding_format".to_owned(), json!(encoding_format)); } if let Some(user) = &self.user { - body["user"] = json!(user); + body_object.insert("user".to_owned(), json!(user)); } let body = serde_json::to_vec(&body)?; diff --git a/rig/rig-core/src/providers/perplexity.rs b/rig/rig-core/src/providers/perplexity.rs index a1b670521..f30db9a92 100644 --- a/rig/rig-core/src/providers/perplexity.rs +++ b/rig/rig-core/src/providers/perplexity.rs @@ -84,16 +84,16 @@ pub type ClientBuilder = impl ProviderClient for Client { type Input = String; + type Error = crate::client::ProviderClientError; /// Create a new Perplexity client from the `PERPLEXITY_API_KEY` environment variable. - /// Panics if the environment variable is not set. - fn from_env() -> Self { - let api_key = std::env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set"); - Self::new(&api_key).unwrap() + fn from_env() -> Result { + let api_key = crate::client::required_env_var("PERPLEXITY_API_KEY")?; + Self::new(&api_key).map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(&input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(&input).map_err(Into::into) } } diff --git a/rig/rig-core/src/providers/together/client.rs b/rig/rig-core/src/providers/together/client.rs index fb84096a9..3f8d63270 100644 --- a/rig/rig-core/src/providers/together/client.rs +++ b/rig/rig-core/src/providers/together/client.rs @@ -60,16 +60,16 @@ impl ProviderBuilder for TogetherExtBuilder { impl ProviderClient for Client { type Input = String; + type Error = crate::client::ProviderClientError; /// Create a new Together AI client from the `TOGETHER_API_KEY` environment variable. - /// Panics if the environment variable is not set. - fn from_env() -> Self { - let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set"); - Self::new(&api_key).unwrap() + fn from_env() -> Result { + let api_key = crate::client::required_env_var("TOGETHER_API_KEY")?; + Self::new(&api_key).map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(&input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(&input).map_err(Into::into) } } diff --git a/rig/rig-core/src/providers/voyageai.rs b/rig/rig-core/src/providers/voyageai.rs index 34cdf99b0..6e04a8d42 100644 --- a/rig/rig-core/src/providers/voyageai.rs +++ b/rig/rig-core/src/providers/voyageai.rs @@ -67,16 +67,16 @@ pub type ClientBuilder = client::ClientBuilder Self { - let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set"); - Self::new(&api_key).unwrap() + fn from_env() -> Result { + let api_key = crate::client::required_env_var("VOYAGE_API_KEY")?; + Self::new(&api_key).map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(&input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(&input).map_err(Into::into) } } diff --git a/rig/rig-core/src/providers/xai/client.rs b/rig/rig-core/src/providers/xai/client.rs index f16f5ee1a..e3ac6f69f 100644 --- a/rig/rig-core/src/providers/xai/client.rs +++ b/rig/rig-core/src/providers/xai/client.rs @@ -59,16 +59,16 @@ impl ProviderBuilder for XAiExtBuilder { impl ProviderClient for Client { type Input = String; + type Error = crate::client::ProviderClientError; /// Create a new xAI client from the `XAI_API_KEY` environment variable. - /// Panics if the environment variable is not set. 
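// The embedding hunks above stop writing optional fields with
// `body["key"] = ...` (an `indexing_slicing` hit that can also panic when the
// value is not an object) and instead fetch the object map once with
// `as_object_mut()`. A self-contained version of that pattern, with a
// placeholder model name and a plain String error for brevity:
use serde_json::{Value, json};

fn build_embedding_body(documents: &[String], ndims: usize) -> Result<Value, String> {
    let mut body = json!({
        "model": "example-embedding-model", // placeholder, not a real model id
        "input": documents,
    });

    let body_object = body
        .as_object_mut()
        .ok_or_else(|| "embedding request body must be a JSON object".to_owned())?;

    if ndims > 0 {
        body_object.insert("dimensions".to_owned(), json!(ndims));
    }

    Ok(body)
}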
- fn from_env() -> Self { - let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set"); - Self::new(&api_key).unwrap() + fn from_env() -> Result { + let api_key = crate::client::required_env_var("XAI_API_KEY")?; + Self::new(&api_key).map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(&input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(&input).map_err(Into::into) } } #[cfg(test)] diff --git a/rig/rig-core/src/providers/zai.rs b/rig/rig-core/src/providers/zai.rs index daa94a96d..6df8fa87a 100644 --- a/rig/rig-core/src/providers/zai.rs +++ b/rig/rig-core/src/providers/zai.rs @@ -169,47 +169,55 @@ impl super::anthropic::completion::AnthropicCompatibleProvider for ZAiAnthropicE impl ProviderClient for Client { type Input = ZAiApiKey; + type Error = crate::client::ProviderClientError; - fn from_env() -> Self { - let api_key = std::env::var("ZAI_API_KEY").expect("ZAI_API_KEY not set"); + fn from_env() -> Result { + let api_key = crate::client::required_env_var("ZAI_API_KEY")?; let mut builder = Self::builder().api_key(api_key); - if let Ok(base_url) = std::env::var("ZAI_API_BASE") { + if let Some(base_url) = crate::client::optional_env_var("ZAI_API_BASE")? { builder = builder.base_url(base_url); } - builder.build().unwrap() + builder.build().map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::new(input).unwrap() + fn from_val(input: Self::Input) -> Result { + Self::new(input).map_err(Into::into) } } impl ProviderClient for AnthropicClient { type Input = String; + type Error = crate::client::ProviderClientError; - fn from_env() -> Self { - let api_key = std::env::var("ZAI_API_KEY").expect("ZAI_API_KEY not set"); + fn from_env() -> Result { + let api_key = crate::client::required_env_var("ZAI_API_KEY")?; let mut builder = Self::builder().api_key(api_key); - if let Some(base_url) = anthropic_base_override("ZAI_ANTHROPIC_API_BASE", "ZAI_API_BASE") { + if let Some(base_url) = anthropic_base_override("ZAI_ANTHROPIC_API_BASE", "ZAI_API_BASE")? 
{ builder = builder.base_url(base_url); } - builder.build().unwrap() + builder.build().map_err(Into::into) } - fn from_val(input: Self::Input) -> Self { - Self::builder().api_key(input).build().unwrap() + fn from_val(input: Self::Input) -> Result { + Self::builder().api_key(input).build().map_err(Into::into) } } -fn anthropic_base_override(primary_env: &str, fallback_env: &str) -> Option { - let primary = std::env::var(primary_env).ok(); - let fallback = std::env::var(fallback_env).ok(); - - resolve_anthropic_base_override(primary.as_deref(), fallback.as_deref()) +fn anthropic_base_override( + primary_env: &'static str, + fallback_env: &'static str, +) -> crate::client::ProviderClientResult> { + let primary = crate::client::optional_env_var(primary_env)?; + let fallback = crate::client::optional_env_var(fallback_env)?; + + Ok(resolve_anthropic_base_override( + primary.as_deref(), + fallback.as_deref(), + )) } fn resolve_anthropic_base_override( diff --git a/rig/rig-core/src/streaming.rs b/rig/rig-core/src/streaming.rs index 07b47a03f..dd285af73 100644 --- a/rig/rig-core/src/streaming.rs +++ b/rig/rig-core/src/streaming.rs @@ -45,11 +45,11 @@ impl PauseControl { } pub fn pause(&self) { - self.paused_tx.send(true).unwrap(); + let _ = self.paused_tx.send(true); } pub fn resume(&self) { - self.paused_tx.send(false).unwrap(); + let _ = self.paused_tx.send(false); } pub fn is_paused(&self) -> bool { @@ -331,8 +331,11 @@ where stream.assistant_items.push(AssistantContent::text("")); } - stream.choice = OneOrMany::many(std::mem::take(&mut stream.assistant_items)) - .expect("There should be at least one assistant message"); + if let Some(choice) = + OneOrMany::from_iter_optional(std::mem::take(&mut stream.assistant_items)) + { + stream.choice = choice; + } Poll::Ready(None) } @@ -571,9 +574,10 @@ where println!("\nResult: {res}"); } Ok(StreamedAssistantContent::Final(res)) => { - let json_res = serde_json::to_string_pretty(&res).unwrap(); - println!(); - tracing::info!("Final result: {json_res}"); + if let Ok(json_res) = serde_json::to_string_pretty(&res) { + println!(); + tracing::info!("Final result: {json_res}"); + } } Ok(StreamedAssistantContent::Reasoning(reasoning)) => { if !is_reasoning { diff --git a/rig/rig-core/src/telemetry/mod.rs b/rig/rig-core/src/telemetry/mod.rs index 9a79003bb..9d62ec32d 100644 --- a/rig/rig-core/src/telemetry/mod.rs +++ b/rig/rig-core/src/telemetry/mod.rs @@ -98,10 +98,9 @@ impl SpanCombinator for tracing::Span { return; } - let input_as_json_string = - serde_json::to_string(input).expect("Serializing a Rust type to JSON should not break"); - - self.record("gen_ai.input.messages", input_as_json_string); + if let Ok(input_as_json_string) = serde_json::to_string(input) { + self.record("gen_ai.input.messages", input_as_json_string); + } } fn record_model_output(&self, output: &T) @@ -112,9 +111,8 @@ impl SpanCombinator for tracing::Span { return; } - let output_as_json_string = serde_json::to_string(output) - .expect("Serializing a Rust type to JSON should not break"); - - self.record("gen_ai.output.messages", output_as_json_string); + if let Ok(output_as_json_string) = serde_json::to_string(output) { + self.record("gen_ai.output.messages", output_as_json_string); + } } } diff --git a/rig/rig-core/src/tool/mod.rs b/rig/rig-core/src/tool/mod.rs index dacba1780..453a3bdc6 100644 --- a/rig/rig-core/src/tool/mod.rs +++ b/rig/rig-core/src/tool/mod.rs @@ -352,7 +352,7 @@ impl ToolSet { if let Some(tool) = self.tools.get(toolname) { tracing::debug!(target: "rig", "Calling 
tool {toolname} with args:\n{}", - serde_json::to_string_pretty(&args).unwrap() + args ); Ok(tool.call(args).await?) } else { diff --git a/rig/rig-core/src/vector_store/in_memory_store.rs b/rig/rig-core/src/vector_store/in_memory_store.rs index c50b5a22b..f7fc67f4e 100644 --- a/rig/rig-core/src/vector_store/in_memory_store.rs +++ b/rig/rig-core/src/vector_store/in_memory_store.rs @@ -194,12 +194,10 @@ impl InMemoryVectorStore { _num_hyperplanes: usize, ) -> EmbeddingRanking<'_, D> { // If we don't have an LSH index yet, fall back to brute force - if self.lsh_index.is_none() { + let Some(lsh_index) = self.lsh_index.as_ref() else { tracing::warn!("LSH index not initialized, falling back to brute force search"); return self.vector_search_brute_force(prompt_embedding, n); - } - - let lsh_index = self.lsh_index.as_ref().unwrap(); + }; let candidates = lsh_index.query(&prompt_embedding.vec); // Sort documents by best embedding distance, but only check candidates diff --git a/rig/rig-core/src/vector_store/lsh.rs b/rig/rig-core/src/vector_store/lsh.rs index 993bd2dc4..759c553d5 100644 --- a/rig/rig-core/src/vector_store/lsh.rs +++ b/rig/rig-core/src/vector_store/lsh.rs @@ -1,6 +1,16 @@ use fastrand::Rng; use std::collections::HashMap; +#[cfg(test)] +fn lsh_rng() -> Rng { + Rng::with_seed(0x5eed_fade_cafe_beef) +} + +#[cfg(not(test))] +fn lsh_rng() -> Rng { + Rng::new() +} + /// Locality Sensitive Hashing (LSH) with random projection. /// Uses random hyperplanes to hash similar vectors into the same buckets for efficient /// approximate nearest neighbor search. See @@ -15,7 +25,7 @@ pub struct LSH { impl LSH { /// Create a new LSH instance. pub fn new(dim: usize, num_tables: usize, num_hyperplanes: usize) -> Self { - let mut rng = Rng::new(); + let mut rng = lsh_rng(); let mut hyperplanes = Vec::new(); for _ in 0..(num_tables * num_hyperplanes) { @@ -53,7 +63,10 @@ impl LSH { let mut hash = 0u64; let start = table_idx * self.num_hyperplanes; - for (i, hyperplane) in self.hyperplanes[start..start + self.num_hyperplanes] + for (i, hyperplane) in self + .hyperplanes + .get(start..start + self.num_hyperplanes) + .unwrap_or(&[]) .iter() .enumerate() { @@ -96,10 +109,9 @@ impl LSHIndex { pub fn insert(&mut self, id: String, embedding: &[f64]) { for table_idx in 0..self.lsh.num_tables { let hash = self.lsh.hash(embedding, table_idx); - self.tables[table_idx] - .entry(hash) - .or_default() - .push(id.clone()); + if let Some(table) = self.tables.get_mut(table_idx) { + table.entry(hash).or_default().push(id.clone()); + } } } @@ -113,7 +125,11 @@ impl LSHIndex { for table_idx in 0..self.lsh.num_tables { let hash = self.lsh.hash(embedding, table_idx); - if let Some(ids) = self.tables[table_idx].get(&hash) { + if let Some(ids) = self + .tables + .get(table_idx) + .and_then(|table| table.get(&hash)) + { candidates.extend(ids.iter().cloned()); } } diff --git a/rig/rig-core/tests/anthropic.rs b/rig/rig-core/tests/anthropic.rs index 01022b92a..d5301ef3a 100644 --- a/rig/rig-core/tests/anthropic.rs +++ b/rig/rig-core/tests/anthropic.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Anthropic integration tests. //! //! 
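// From here on the patch updates integration tests: each test file re-allows
// the panicking lints with a file-level `#![allow(...)]`, and every
// `Client::from_env()` call gains `.expect("client should build")` to match the
// now-fallible constructor. A representative (not literal) test sketch, with
// indicative import paths and a placeholder prompt and assertion:
#![allow(clippy::expect_used)]

use rig::completion::Prompt;
use rig::prelude::*;
use rig::providers::anthropic;

#[tokio::test]
#[ignore = "requires ANTHROPIC_API_KEY"]
async fn completion_smoke() {
    // Tests may still panic loudly; only library code must stay panic-free.
    let client = anthropic::Client::from_env().expect("client should build");
    let agent = client
        .agent(anthropic::completion::CLAUDE_SONNET_4_6)
        .preamble("You are a helpful assistant.")
        .build();

    let response = agent
        .prompt("Say hello.")
        .await
        .expect("prompt should succeed");
    assert!(!response.is_empty());
}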
Run the full provider target with: diff --git a/rig/rig-core/tests/anthropic/agent.rs b/rig/rig-core/tests/anthropic/agent.rs index 469545ab4..5c73e3b82 100644 --- a/rig/rig-core/tests/anthropic/agent.rs +++ b/rig/rig-core/tests/anthropic/agent.rs @@ -9,7 +9,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY"] async fn completion_smoke() { - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env().expect("client should build"); let agent = client .agent(anthropic::completion::CLAUDE_SONNET_4_6) .preamble(BASIC_PREAMBLE) diff --git a/rig/rig-core/tests/anthropic/default_max_turns.rs b/rig/rig-core/tests/anthropic/default_max_turns.rs index dbabf3222..3a58da712 100644 --- a/rig/rig-core/tests/anthropic/default_max_turns.rs +++ b/rig/rig-core/tests/anthropic/default_max_turns.rs @@ -82,6 +82,7 @@ impl Tool for Divide { #[ignore = "requires ANTHROPIC_API_KEY"] async fn default_max_turns_allows_multi_step_tool_use() -> Result<()> { let agent = anthropic::Client::from_env() + .expect("client should build") .agent(anthropic::completion::CLAUDE_SONNET_4_6) .preamble( "You are an assistant that must use the available tools for arithmetic. \ diff --git a/rig/rig-core/tests/anthropic/empty_end_turn.rs b/rig/rig-core/tests/anthropic/empty_end_turn.rs index d6aeedb1d..f09754881 100644 --- a/rig/rig-core/tests/anthropic/empty_end_turn.rs +++ b/rig/rig-core/tests/anthropic/empty_end_turn.rs @@ -131,7 +131,9 @@ fn history_has_empty_assistant_text(messages: &[Message]) -> bool { #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY"] async fn raw_followup_empty_end_turn_normalizes_to_empty_text_choice() { - let model = anthropic::Client::from_env().completion_model(CLAUDE_SONNET_4_6); + let model = anthropic::Client::from_env() + .expect("client should build") + .completion_model(CLAUDE_SONNET_4_6); let first_turn = model .completion_request(TERMINAL_NOTIFY_PROMPT) @@ -189,6 +191,7 @@ async fn raw_followup_empty_end_turn_normalizes_to_empty_text_choice() { async fn prompt_loop_accepts_empty_terminal_turn_after_tool_result() { let call_count = Arc::new(AtomicUsize::new(0)); let agent = anthropic::Client::from_env() + .expect("client should build") .agent(CLAUDE_SONNET_4_6) .preamble(TERMINAL_NOTIFY_PREAMBLE) .max_tokens(1024) @@ -237,6 +240,7 @@ async fn prompt_loop_accepts_empty_terminal_turn_after_tool_result() { async fn prompt_loop_preserves_pre_tool_text_when_terminal_followup_is_empty() { let call_count = Arc::new(AtomicUsize::new(0)); let agent = anthropic::Client::from_env() + .expect("client should build") .agent(CLAUDE_SONNET_4_6) .preamble(TERMINAL_NOTIFY_WITH_ACK_PREAMBLE) .max_tokens(1024) diff --git a/rig/rig-core/tests/anthropic/image.rs b/rig/rig-core/tests/anthropic/image.rs index 40f12b420..de112a9a5 100644 --- a/rig/rig-core/tests/anthropic/image.rs +++ b/rig/rig-core/tests/anthropic/image.rs @@ -16,7 +16,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY"] async fn image_prompt_from_fixture() { - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env().expect("client should build"); let agent = client .agent(anthropic::completion::CLAUDE_SONNET_4_6) .preamble("You are an image describer.") diff --git a/rig/rig-core/tests/anthropic/models.rs b/rig/rig-core/tests/anthropic/models.rs index 926f5b09a..2dbad2781 100644 --- a/rig/rig-core/tests/anthropic/models.rs +++ b/rig/rig-core/tests/anthropic/models.rs @@ -6,7 
+6,7 @@ use rig::providers::anthropic; #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY"] async fn list_models_smoke() { - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env().expect("client should build"); let models = match client.list_models().await { Ok(models) => models, Err(error) => { diff --git a/rig/rig-core/tests/anthropic/multi_turn_streaming.rs b/rig/rig-core/tests/anthropic/multi_turn_streaming.rs index b9cf56130..b32633064 100644 --- a/rig/rig-core/tests/anthropic/multi_turn_streaming.rs +++ b/rig/rig-core/tests/anthropic/multi_turn_streaming.rs @@ -158,7 +158,7 @@ async fn multi_turn_streaming_tools() { let multiply_calls = Arc::new(AtomicUsize::new(0)); let divide_calls = Arc::new(AtomicUsize::new(0)); - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env().expect("client should build"); let agent = client .agent(anthropic::completion::CLAUDE_SONNET_4_6) .preamble("You must use tools for arithmetic.") diff --git a/rig/rig-core/tests/anthropic/plaintext_document.rs b/rig/rig-core/tests/anthropic/plaintext_document.rs index a2f12b5c1..0caf34e6e 100644 --- a/rig/rig-core/tests/anthropic/plaintext_document.rs +++ b/rig/rig-core/tests/anthropic/plaintext_document.rs @@ -28,7 +28,7 @@ Key Features: #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY"] async fn plaintext_document_prompt() { - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env().expect("client should build"); let agent = client .agent(anthropic::completion::CLAUDE_SONNET_4_6) .preamble("You are a helpful assistant that analyzes documents.") @@ -52,7 +52,7 @@ async fn plaintext_document_prompt() { #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY"] async fn plaintext_document_with_instruction() { - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env().expect("client should build"); let agent = client .agent(anthropic::completion::CLAUDE_SONNET_4_6) .preamble("You are a helpful assistant that analyzes documents.") diff --git a/rig/rig-core/tests/anthropic/reasoning_roundtrip.rs b/rig/rig-core/tests/anthropic/reasoning_roundtrip.rs index 534245b17..3b83613ac 100644 --- a/rig/rig-core/tests/anthropic/reasoning_roundtrip.rs +++ b/rig/rig-core/tests/anthropic/reasoning_roundtrip.rs @@ -11,7 +11,7 @@ use crate::reasoning::{self, ReasoningRoundtripAgent}; #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY"] async fn streaming() { - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env().expect("client should build"); reasoning::run_reasoning_roundtrip_streaming(ReasoningRoundtripAgent::new( client.completion_model(CLAUDE_SONNET_4_6), Some(serde_json::json!({ @@ -24,7 +24,7 @@ async fn streaming() { #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY"] async fn nonstreaming() { - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env().expect("client should build"); reasoning::run_reasoning_roundtrip_nonstreaming(ReasoningRoundtripAgent::new( client.completion_model(CLAUDE_SONNET_4_6), Some(serde_json::json!({ diff --git a/rig/rig-core/tests/anthropic/reasoning_tool_roundtrip.rs b/rig/rig-core/tests/anthropic/reasoning_tool_roundtrip.rs index db3e9646e..c9a6e80a9 100644 --- a/rig/rig-core/tests/anthropic/reasoning_tool_roundtrip.rs +++ b/rig/rig-core/tests/anthropic/reasoning_tool_roundtrip.rs @@ -17,7 +17,7 @@ use crate::reasoning::{self, WeatherTool}; #[ignore = "requires ANTHROPIC_API_KEY"] async fn 
streaming() { let call_count = Arc::new(AtomicUsize::new(0)); - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env().expect("client should build"); let agent = client .agent(CLAUDE_SONNET_4_6) .preamble(reasoning::TOOL_SYSTEM_PROMPT) @@ -54,7 +54,7 @@ async fn streaming() { #[ignore = "requires ANTHROPIC_API_KEY"] async fn nonstreaming() { let call_count = Arc::new(AtomicUsize::new(0)); - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env().expect("client should build"); let agent = client .agent(CLAUDE_SONNET_4_6) .preamble(reasoning::TOOL_SYSTEM_PROMPT) diff --git a/rig/rig-core/tests/anthropic/streaming.rs b/rig/rig-core/tests/anthropic/streaming.rs index 026586146..4333f57b4 100644 --- a/rig/rig-core/tests/anthropic/streaming.rs +++ b/rig/rig-core/tests/anthropic/streaming.rs @@ -11,7 +11,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY"] async fn streaming_smoke() { - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env().expect("client should build"); let agent = client .agent(anthropic::completion::CLAUDE_SONNET_4_6) .preamble(STREAMING_PREAMBLE) diff --git a/rig/rig-core/tests/anthropic/streaming_tools.rs b/rig/rig-core/tests/anthropic/streaming_tools.rs index f2c6819f0..8a6414808 100644 --- a/rig/rig-core/tests/anthropic/streaming_tools.rs +++ b/rig/rig-core/tests/anthropic/streaming_tools.rs @@ -12,7 +12,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY"] async fn streaming_tools_smoke() { - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env().expect("client should build"); let agent = client .agent(anthropic::completion::CLAUDE_SONNET_4_6) .preamble(STREAMING_TOOLS_PREAMBLE) diff --git a/rig/rig-core/tests/anthropic/structured_output.rs b/rig/rig-core/tests/anthropic/structured_output.rs index 2e992ed49..0b410d229 100644 --- a/rig/rig-core/tests/anthropic/structured_output.rs +++ b/rig/rig-core/tests/anthropic/structured_output.rs @@ -11,7 +11,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY"] async fn structured_output_smoke() { - let client = anthropic::Client::from_env(); + let client = anthropic::Client::from_env().expect("client should build"); let agent = client .agent(CLAUDE_SONNET_4_6) .output_schema::() diff --git a/rig/rig-core/tests/anthropic/think_tool.rs b/rig/rig-core/tests/anthropic/think_tool.rs index 4e441de2c..48089a426 100644 --- a/rig/rig-core/tests/anthropic/think_tool.rs +++ b/rig/rig-core/tests/anthropic/think_tool.rs @@ -11,6 +11,7 @@ use crate::support::{assert_contains_any_case_insensitive, assert_nonempty_respo #[ignore = "requires ANTHROPIC_API_KEY"] async fn think_tool_menu_planning() { let agent = anthropic::Client::from_env() + .expect("client should build") .agent(anthropic::completion::CLAUDE_SONNET_4_6) .name("Anthropic Thinker") .preamble( diff --git a/rig/rig-core/tests/azure.rs b/rig/rig-core/tests/azure.rs index 1e8f9150c..1f69028c3 100644 --- a/rig/rig-core/tests/azure.rs +++ b/rig/rig-core/tests/azure.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Azure integration tests. //! //! 
Run the full provider target with: diff --git a/rig/rig-core/tests/azure/transcription.rs b/rig/rig-core/tests/azure/transcription.rs index b0824378d..82b260e0d 100644 --- a/rig/rig-core/tests/azure/transcription.rs +++ b/rig/rig-core/tests/azure/transcription.rs @@ -10,7 +10,7 @@ use crate::support::{AUDIO_FIXTURE_PATH, assert_nonempty_response}; #[tokio::test] #[ignore = "requires AZURE_OPENAI_API_KEY and related Azure env vars"] async fn transcription_smoke() { - let client = azure::Client::from_env(); + let client = azure::Client::from_env().expect("client should build"); let model = client.transcription_model("whisper"); let response = model .transcription_request() diff --git a/rig/rig-core/tests/chatgpt.rs b/rig/rig-core/tests/chatgpt.rs index 43daa113d..e19dd7f12 100644 --- a/rig/rig-core/tests/chatgpt.rs +++ b/rig/rig-core/tests/chatgpt.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! ChatGPT integration tests. //! //! Run the provider target with: diff --git a/rig/rig-core/tests/cohere.rs b/rig/rig-core/tests/cohere.rs index efa3d62e1..3f95f08cd 100644 --- a/rig/rig-core/tests/cohere.rs +++ b/rig/rig-core/tests/cohere.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Cohere integration tests. //! //! Run the full provider target with: diff --git a/rig/rig-core/tests/cohere/agent.rs b/rig/rig-core/tests/cohere/agent.rs index 1707622d0..f890a0313 100644 --- a/rig/rig-core/tests/cohere/agent.rs +++ b/rig/rig-core/tests/cohere/agent.rs @@ -9,7 +9,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[tokio::test] #[ignore = "requires COHERE_API_KEY"] async fn completion_smoke() { - let client = cohere::Client::from_env(); + let client = cohere::Client::from_env().expect("client should build"); let agent = client .agent(cohere::COMMAND_R) .preamble(BASIC_PREAMBLE) diff --git a/rig/rig-core/tests/cohere/streaming.rs b/rig/rig-core/tests/cohere/streaming.rs index 0ed55a3b2..d081ac394 100644 --- a/rig/rig-core/tests/cohere/streaming.rs +++ b/rig/rig-core/tests/cohere/streaming.rs @@ -11,7 +11,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires COHERE_API_KEY"] async fn streaming_smoke() { - let client = cohere::Client::from_env(); + let client = cohere::Client::from_env().expect("client should build"); let agent = client .agent(cohere::COMMAND) .preamble(STREAMING_PREAMBLE) diff --git a/rig/rig-core/tests/cohere/streaming_tools.rs b/rig/rig-core/tests/cohere/streaming_tools.rs index 2cb5989db..012fcf3bc 100644 --- a/rig/rig-core/tests/cohere/streaming_tools.rs +++ b/rig/rig-core/tests/cohere/streaming_tools.rs @@ -12,7 +12,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires COHERE_API_KEY"] async fn streaming_tools_smoke() { - let client = cohere::Client::from_env(); + let client = cohere::Client::from_env().expect("client should build"); let agent = client .agent(cohere::COMMAND_R) .preamble(STREAMING_TOOLS_PREAMBLE) diff --git a/rig/rig-core/tests/cohere/tools.rs b/rig/rig-core/tests/cohere/tools.rs index bacc688c7..9d8d8010e 100644 --- a/rig/rig-core/tests/cohere/tools.rs +++ b/rig/rig-core/tests/cohere/tools.rs @@ -11,7 +11,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires COHERE_API_KEY"] async fn tools_smoke() { - let client = cohere::Client::from_env(); + let client = cohere::Client::from_env().expect("client 
should build"); let agent = client .agent(cohere::COMMAND_R) .preamble(TOOLS_PREAMBLE) diff --git a/rig/rig-core/tests/copilot.rs b/rig/rig-core/tests/copilot.rs index 5d97bcf0e..74339ceaa 100644 --- a/rig/rig-core/tests/copilot.rs +++ b/rig/rig-core/tests/copilot.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Copilot integration tests. //! //! Run the provider target with: diff --git a/rig/rig-core/tests/core.rs b/rig/rig-core/tests/core.rs index 88d1e2eab..807574004 100644 --- a/rig/rig-core/tests/core.rs +++ b/rig/rig-core/tests/core.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Core integration tests that are not provider-specific. //! //! Run the target with: diff --git a/rig/rig-core/tests/deepseek.rs b/rig/rig-core/tests/deepseek.rs index 18dab7eb3..172b65bec 100644 --- a/rig/rig-core/tests/deepseek.rs +++ b/rig/rig-core/tests/deepseek.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! DeepSeek integration tests. //! //! Run the full provider target with: diff --git a/rig/rig-core/tests/deepseek/agent.rs b/rig/rig-core/tests/deepseek/agent.rs index 142753cd8..2c808582a 100644 --- a/rig/rig-core/tests/deepseek/agent.rs +++ b/rig/rig-core/tests/deepseek/agent.rs @@ -9,7 +9,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn completion_smoke() { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let agent = client .agent(deepseek::DEEPSEEK_CHAT) .preamble(BASIC_PREAMBLE) diff --git a/rig/rig-core/tests/deepseek/extractor.rs b/rig/rig-core/tests/deepseek/extractor.rs index 9d00418dd..1549c765a 100644 --- a/rig/rig-core/tests/deepseek/extractor.rs +++ b/rig/rig-core/tests/deepseek/extractor.rs @@ -8,7 +8,7 @@ use crate::support::{EXTRACTOR_TEXT, SmokePerson, assert_nonempty_response}; #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn extractor_smoke() { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let extractor = client .extractor::(deepseek::DEEPSEEK_CHAT) .build(); diff --git a/rig/rig-core/tests/deepseek/extractor_usage.rs b/rig/rig-core/tests/deepseek/extractor_usage.rs index fe7fc8d4c..e799c527e 100644 --- a/rig/rig-core/tests/deepseek/extractor_usage.rs +++ b/rig/rig-core/tests/deepseek/extractor_usage.rs @@ -39,7 +39,7 @@ fn assert_compatible_professions(left: Option<&str>, right: &str) { #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn extract_backward_compatibility() -> Result<()> { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let extractor = client.extractor::(deepseek::DEEPSEEK_CHAT).build(); let person = extractor @@ -56,7 +56,7 @@ async fn extract_backward_compatibility() -> Result<()> { #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn extract_with_usage_returns_data_and_usage() -> Result<()> { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let extractor = client.extractor::(deepseek::DEEPSEEK_CHAT).build(); let response: ExtractionResponse 
= extractor @@ -76,7 +76,7 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn extract_with_chat_history_with_usage_works() -> Result<()> { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let extractor = client.extractor::
(deepseek::DEEPSEEK_CHAT).build(); let chat_history = vec![Message::user( @@ -103,7 +103,7 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let extractor = client.extractor::(deepseek::DEEPSEEK_CHAT).build(); let text = "Bob Johnson is a 55 year old retired teacher."; @@ -124,7 +124,7 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn usage_tracking_works_for_different_schemas() -> Result<()> { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let person_extractor = client.extractor::(deepseek::DEEPSEEK_CHAT).build(); let person_response = person_extractor diff --git a/rig/rig-core/tests/deepseek/multi_extract.rs b/rig/rig-core/tests/deepseek/multi_extract.rs index 5d7fe8201..9a5cdeba4 100644 --- a/rig/rig-core/tests/deepseek/multi_extract.rs +++ b/rig/rig-core/tests/deepseek/multi_extract.rs @@ -29,7 +29,7 @@ struct Sentiment { #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn batch_multi_extract_chain() -> Result<()> { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let names_extractor = client .extractor::(deepseek::DEEPSEEK_CHAT) .preamble("Extract names from the given text.") diff --git a/rig/rig-core/tests/deepseek/permission_control.rs b/rig/rig-core/tests/deepseek/permission_control.rs index 037868104..714683d73 100644 --- a/rig/rig-core/tests/deepseek/permission_control.rs +++ b/rig/rig-core/tests/deepseek/permission_control.rs @@ -154,6 +154,7 @@ async fn permission_control_prompt_example() -> Result<()> { let _cleanup = FileCleanup::new()?; let agent = deepseek::Client::from_env() + .expect("client should build") .agent(deepseek::DEEPSEEK_CHAT) .preamble("You are a helpful assistant that can read files using different methods.") .tool(ReadFileHead) @@ -189,6 +190,7 @@ async fn permission_control_streaming_example() -> Result<()> { let _cleanup = FileCleanup::new()?; let agent = deepseek::Client::from_env() + .expect("client should build") .agent(deepseek::DEEPSEEK_CHAT) .preamble("You are a helpful assistant that can read files using different methods.") .tool(ReadFileHead) diff --git a/rig/rig-core/tests/deepseek/reasoning_roundtrip.rs b/rig/rig-core/tests/deepseek/reasoning_roundtrip.rs index 97d200fc5..c4c239c63 100644 --- a/rig/rig-core/tests/deepseek/reasoning_roundtrip.rs +++ b/rig/rig-core/tests/deepseek/reasoning_roundtrip.rs @@ -8,7 +8,7 @@ use crate::reasoning::{self, ReasoningRoundtripAgent}; #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn streaming() { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); reasoning::run_reasoning_roundtrip_streaming(ReasoningRoundtripAgent::new( client.completion_model(deepseek::DEEPSEEK_REASONER), None, @@ -19,7 +19,7 @@ async fn streaming() { #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn nonstreaming() { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); reasoning::run_reasoning_roundtrip_nonstreaming(ReasoningRoundtripAgent::new( 
client.completion_model(deepseek::DEEPSEEK_REASONER), None, diff --git a/rig/rig-core/tests/deepseek/reasoning_tool_roundtrip.rs b/rig/rig-core/tests/deepseek/reasoning_tool_roundtrip.rs index 877d727c7..88c26df1a 100644 --- a/rig/rig-core/tests/deepseek/reasoning_tool_roundtrip.rs +++ b/rig/rig-core/tests/deepseek/reasoning_tool_roundtrip.rs @@ -14,7 +14,7 @@ use crate::reasoning::{self, WeatherTool}; #[ignore = "requires DEEPSEEK_API_KEY"] async fn streaming() { let call_count = Arc::new(AtomicUsize::new(0)); - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let agent = client .agent(deepseek::DEEPSEEK_REASONER) .preamble(reasoning::TOOL_SYSTEM_PROMPT) @@ -43,7 +43,7 @@ async fn streaming() { #[ignore = "requires DEEPSEEK_API_KEY"] async fn nonstreaming() { let call_count = Arc::new(AtomicUsize::new(0)); - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let agent = client .agent(deepseek::DEEPSEEK_REASONER) .preamble(reasoning::TOOL_SYSTEM_PROMPT) diff --git a/rig/rig-core/tests/deepseek/request_hook.rs b/rig/rig-core/tests/deepseek/request_hook.rs index bea0c1bd0..ebc78cfd9 100644 --- a/rig/rig-core/tests/deepseek/request_hook.rs +++ b/rig/rig-core/tests/deepseek/request_hook.rs @@ -69,6 +69,7 @@ where #[ignore = "requires DEEPSEEK_API_KEY"] async fn request_hook_records_prompt_and_response() -> Result<()> { let agent = deepseek::Client::from_env() + .expect("client should build") .agent(deepseek::DEEPSEEK_CHAT) .preamble("You are a comedian here to entertain the user using humour and jokes.") .build(); diff --git a/rig/rig-core/tests/deepseek/streaming.rs b/rig/rig-core/tests/deepseek/streaming.rs index 8c011b265..1a7db9b25 100644 --- a/rig/rig-core/tests/deepseek/streaming.rs +++ b/rig/rig-core/tests/deepseek/streaming.rs @@ -9,7 +9,7 @@ use crate::support::{assert_nonempty_response, collect_stream_final_response}; #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn streaming_prompt_smoke() { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let agent = client .agent(DEEPSEEK_CHAT) .preamble("You are a helpful assistant.") diff --git a/rig/rig-core/tests/deepseek/streaming_tools.rs b/rig/rig-core/tests/deepseek/streaming_tools.rs index 37d156dc2..5f6954db0 100644 --- a/rig/rig-core/tests/deepseek/streaming_tools.rs +++ b/rig/rig-core/tests/deepseek/streaming_tools.rs @@ -22,7 +22,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn streaming_chat_with_tools() { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let agent = client .agent(DEEPSEEK_CHAT) .preamble("You are a calculator here to help the user perform arithmetic operations.") @@ -43,7 +43,7 @@ async fn streaming_chat_with_tools() { #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn raw_stream_emits_required_zero_arg_tool_call() { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let model = client.completion_model(DEEPSEEK_CHAT); let request = model .completion_request(REQUIRED_ZERO_ARG_TOOL_PROMPT) @@ -58,7 +58,7 @@ async fn raw_stream_emits_required_zero_arg_tool_call() { #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn raw_stream_surfaces_two_distinct_tool_calls_before_text() { - let client = 
deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let model = client.completion_model(DEEPSEEK_CHAT); let request = model .completion_request(TWO_TOOL_STREAM_PROMPT) @@ -84,7 +84,7 @@ async fn raw_stream_surfaces_two_distinct_tool_calls_before_text() { #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn streaming_chat_surfaces_two_distinct_tool_calls_before_final_answer() { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let agent = client .agent(DEEPSEEK_CHAT) .preamble(TWO_TOOL_STREAM_PREAMBLE) @@ -109,7 +109,7 @@ async fn streaming_chat_surfaces_two_distinct_tool_calls_before_final_answer() { #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn streaming_chat_emits_tool_call_before_later_text() { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let agent = client .agent(DEEPSEEK_CHAT) .preamble(ORDERED_TOOL_STREAM_PREAMBLE) @@ -133,7 +133,7 @@ async fn streaming_chat_emits_tool_call_before_later_text() { #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn raw_followup_uses_tool_result_without_new_tool_calls() { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let model = client.completion_model(DEEPSEEK_CHAT); let request = model .completion_request(ORDERED_TOOL_STREAM_PROMPT) diff --git a/rig/rig-core/tests/deepseek/tools.rs b/rig/rig-core/tests/deepseek/tools.rs index 27dfe6290..5e94eb3dc 100644 --- a/rig/rig-core/tests/deepseek/tools.rs +++ b/rig/rig-core/tests/deepseek/tools.rs @@ -11,7 +11,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires DEEPSEEK_API_KEY"] async fn tools_smoke() { - let client = deepseek::Client::from_env(); + let client = deepseek::Client::from_env().expect("client should build"); let agent = client .agent(deepseek::DEEPSEEK_CHAT) .preamble(TOOLS_PREAMBLE) diff --git a/rig/rig-core/tests/galadriel.rs b/rig/rig-core/tests/galadriel.rs index bb67ccde4..36411b017 100644 --- a/rig/rig-core/tests/galadriel.rs +++ b/rig/rig-core/tests/galadriel.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Galadriel integration tests. //! //! 
Run the provider target with: diff --git a/rig/rig-core/tests/galadriel/agent.rs b/rig/rig-core/tests/galadriel/agent.rs index 91f59fdce..f98da27ba 100644 --- a/rig/rig-core/tests/galadriel/agent.rs +++ b/rig/rig-core/tests/galadriel/agent.rs @@ -9,7 +9,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[tokio::test] #[ignore = "requires GALADRIEL_API_KEY"] async fn completion_smoke() { - let client = galadriel::Client::from_env(); + let client = galadriel::Client::from_env().expect("galadriel client should build"); let agent = client .agent(galadriel::GPT_4O) .preamble(BASIC_PREAMBLE) diff --git a/rig/rig-core/tests/galadriel/streaming_tools.rs b/rig/rig-core/tests/galadriel/streaming_tools.rs index 7150028bc..c5721b76b 100644 --- a/rig/rig-core/tests/galadriel/streaming_tools.rs +++ b/rig/rig-core/tests/galadriel/streaming_tools.rs @@ -13,7 +13,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires GALADRIEL_API_KEY"] async fn raw_stream_emits_required_zero_arg_tool_call() { - let client = galadriel::Client::from_env(); + let client = galadriel::Client::from_env().expect("galadriel client should build"); let model = client.completion_model(galadriel::GPT_4O); let request = model .completion_request(REQUIRED_ZERO_ARG_TOOL_PROMPT) diff --git a/rig/rig-core/tests/gemini.rs b/rig/rig-core/tests/gemini.rs index d30065972..52072e140 100644 --- a/rig/rig-core/tests/gemini.rs +++ b/rig/rig-core/tests/gemini.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Gemini integration tests. //! //! Run the full provider target with: diff --git a/rig/rig-core/tests/gemini/agent.rs b/rig/rig-core/tests/gemini/agent.rs index 285bc1e6a..590b631d6 100644 --- a/rig/rig-core/tests/gemini/agent.rs +++ b/rig/rig-core/tests/gemini/agent.rs @@ -9,7 +9,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn completion_smoke() { - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let agent = client .agent(gemini::completion::GEMINI_2_5_FLASH) .preamble(BASIC_PREAMBLE) diff --git a/rig/rig-core/tests/gemini/embeddings.rs b/rig/rig-core/tests/gemini/embeddings.rs index cc7b71309..82c23dad0 100644 --- a/rig/rig-core/tests/gemini/embeddings.rs +++ b/rig/rig-core/tests/gemini/embeddings.rs @@ -18,7 +18,7 @@ struct Greetings { #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn embeddings_smoke() { - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let model = client.embedding_model(gemini::embedding::EMBEDDING_001); let embeddings = model @@ -33,7 +33,7 @@ async fn embeddings_smoke() { #[tokio::test] #[ignore = "requires GEMINI_API_KEY and --features derive"] async fn derive_document_embeddings() { - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let embeddings = client .embeddings(gemini::embedding::EMBEDDING_001) .document(Greetings { diff --git a/rig/rig-core/tests/gemini/extractor.rs b/rig/rig-core/tests/gemini/extractor.rs index 9ff7d8f1a..15a892031 100644 --- a/rig/rig-core/tests/gemini/extractor.rs +++ b/rig/rig-core/tests/gemini/extractor.rs @@ -23,7 +23,7 @@ async fn extractor_smoke() { let additional_params = 
AdditionalParameters::default().with_config(GenerationConfig::default()); - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let extractor = client .extractor::(gemini::completion::GEMINI_2_5_FLASH) .additional_params( @@ -56,7 +56,7 @@ async fn extractor_smoke() { #[ignore = "requires GEMINI_API_KEY"] async fn extractor_with_additional_params() { let params = AdditionalParameters::default().with_config(GenerationConfig::default()); - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let extractor = client .extractor::(gemini::completion::GEMINI_2_5_FLASH) .additional_params(serde_json::to_value(params).expect("params should serialize")) diff --git a/rig/rig-core/tests/gemini/interactions_api.rs b/rig/rig-core/tests/gemini/interactions_api.rs index 6d45e4ae3..5e6a592a0 100644 --- a/rig/rig-core/tests/gemini/interactions_api.rs +++ b/rig/rig-core/tests/gemini/interactions_api.rs @@ -32,7 +32,9 @@ fn first_tool_call(choice: &OneOrMany) -> Option { #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn basic_interaction_returns_id() { - let model = gemini::InteractionsClient::from_env().completion_model("gemini-3-flash-preview"); + let model = gemini::InteractionsClient::from_env() + .expect("client should build") + .completion_model("gemini-3-flash-preview"); let params = AdditionalParameters { store: Some(true), ..Default::default() @@ -57,7 +59,9 @@ async fn basic_interaction_returns_id() { #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn followup_with_previous_interaction_id() { - let model = gemini::InteractionsClient::from_env().completion_model("gemini-3-flash-preview"); + let model = gemini::InteractionsClient::from_env() + .expect("client should build") + .completion_model("gemini-3-flash-preview"); let initial = model .completion( model @@ -98,7 +102,9 @@ async fn followup_with_previous_interaction_id() { #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn google_search_tool_interaction() { - let model = gemini::InteractionsClient::from_env().completion_model("gemini-3-flash-preview"); + let model = gemini::InteractionsClient::from_env() + .expect("client should build") + .completion_model("gemini-3-flash-preview"); let response = model .completion( model @@ -125,7 +131,9 @@ async fn google_search_tool_interaction() { #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn tool_result_roundtrip() { - let model = gemini::InteractionsClient::from_env().completion_model("gemini-3-flash-preview"); + let model = gemini::InteractionsClient::from_env() + .expect("client should build") + .completion_model("gemini-3-flash-preview"); let tool = rig::completion::ToolDefinition { name: "add".to_string(), description: "Add two numbers together".to_string(), @@ -191,7 +199,9 @@ async fn tool_result_roundtrip() { #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn streaming_interaction() { - let model = gemini::InteractionsClient::from_env().completion_model("gemini-3-flash-preview"); + let model = gemini::InteractionsClient::from_env() + .expect("client should build") + .completion_model("gemini-3-flash-preview"); let request = model .completion_request("Write a 3-line poem about rust and rivers.") .temperature(0.4) diff --git a/rig/rig-core/tests/gemini/models.rs b/rig/rig-core/tests/gemini/models.rs index 177f154ca..3abc19c1a 100644 --- a/rig/rig-core/tests/gemini/models.rs +++ b/rig/rig-core/tests/gemini/models.rs @@ 
-6,13 +6,11 @@ use rig::providers::gemini; #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn list_models_smoke() { - let client = gemini::Client::from_env(); - let models = match client.list_models().await { - Ok(models) => models, - Err(error) => { - panic!("listing Gemini models should succeed\nDisplay: {error}\nDebug: {error:#?}") - } - }; + let client = gemini::Client::from_env().expect("client should build"); + let models = client + .list_models() + .await + .expect("listing Gemini models should succeed"); println!("Gemini returned {} models", models.len()); diff --git a/rig/rig-core/tests/gemini/multi_turn_streaming.rs b/rig/rig-core/tests/gemini/multi_turn_streaming.rs index 4e6775a9e..5172b6995 100644 --- a/rig/rig-core/tests/gemini/multi_turn_streaming.rs +++ b/rig/rig-core/tests/gemini/multi_turn_streaming.rs @@ -42,7 +42,7 @@ async fn manual_multi_turn_streaming_loop() { let multiply_calls = Arc::new(AtomicUsize::new(0)); let divide_calls = Arc::new(AtomicUsize::new(0)); - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let agent = client .agent(gemini::completion::GEMINI_2_5_FLASH) .preamble("You must use tools to answer arithmetic prompts.") diff --git a/rig/rig-core/tests/gemini/reasoning_roundtrip.rs b/rig/rig-core/tests/gemini/reasoning_roundtrip.rs index 151cc63ec..29db0a855 100644 --- a/rig/rig-core/tests/gemini/reasoning_roundtrip.rs +++ b/rig/rig-core/tests/gemini/reasoning_roundtrip.rs @@ -11,7 +11,7 @@ use crate::reasoning::{self, ReasoningRoundtripAgent}; #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn streaming() { - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); reasoning::run_reasoning_roundtrip_streaming(ReasoningRoundtripAgent::new( client.completion_model("gemini-2.5-flash"), Some(serde_json::json!({ @@ -26,7 +26,7 @@ async fn streaming() { #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn nonstreaming() { - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); reasoning::run_reasoning_roundtrip_nonstreaming(ReasoningRoundtripAgent::new( client.completion_model("gemini-2.5-flash"), Some(serde_json::json!({ diff --git a/rig/rig-core/tests/gemini/reasoning_tool_roundtrip.rs b/rig/rig-core/tests/gemini/reasoning_tool_roundtrip.rs index f807969b8..43d63f04a 100644 --- a/rig/rig-core/tests/gemini/reasoning_tool_roundtrip.rs +++ b/rig/rig-core/tests/gemini/reasoning_tool_roundtrip.rs @@ -17,7 +17,7 @@ use crate::reasoning::{self, WeatherTool}; #[ignore = "requires GEMINI_API_KEY"] async fn streaming() { let call_count = Arc::new(AtomicUsize::new(0)); - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let agent = client .agent("gemini-2.5-flash") .preamble(reasoning::TOOL_SYSTEM_PROMPT) @@ -43,7 +43,7 @@ async fn streaming() { #[ignore = "requires GEMINI_API_KEY"] async fn nonstreaming() { let call_count = Arc::new(AtomicUsize::new(0)); - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let agent = client .agent("gemini-2.5-flash") .preamble(reasoning::TOOL_SYSTEM_PROMPT) diff --git a/rig/rig-core/tests/gemini/streaming.rs b/rig/rig-core/tests/gemini/streaming.rs index 8c10459ee..b3a833394 100644 --- a/rig/rig-core/tests/gemini/streaming.rs +++ b/rig/rig-core/tests/gemini/streaming.rs @@ -24,7 +24,7 @@ async fn 
streaming_smoke() { }; let additional_params = AdditionalParameters::default().with_config(thinking_config); - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let agent = client .agent(gemini::completion::GEMINI_3_FLASH_PREVIEW) .preamble(STREAMING_PREAMBLE) @@ -55,6 +55,7 @@ async fn example_streaming_prompt() { }; let params = AdditionalParameters::default().with_config(generation_config); let agent = gemini::Client::from_env() + .expect("client should build") .agent(gemini::completion::GEMINI_3_FLASH_PREVIEW) .preamble("Be precise and concise.") .temperature(0.5) diff --git a/rig/rig-core/tests/gemini/streaming_multimodal_tool_results.rs b/rig/rig-core/tests/gemini/streaming_multimodal_tool_results.rs index 3973b9fdf..bcbdd2777 100644 --- a/rig/rig-core/tests/gemini/streaming_multimodal_tool_results.rs +++ b/rig/rig-core/tests/gemini/streaming_multimodal_tool_results.rs @@ -71,7 +71,7 @@ impl Tool for HybridImageTool { #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn streaming_history_preserves_hybrid_tool_result_image_parts() { - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let agent = client .agent(MULTIMODAL_FUNCTION_RESPONSE_MODEL) .preamble( diff --git a/rig/rig-core/tests/gemini/streaming_tools.rs b/rig/rig-core/tests/gemini/streaming_tools.rs index 22db22190..e3e7044d0 100644 --- a/rig/rig-core/tests/gemini/streaming_tools.rs +++ b/rig/rig-core/tests/gemini/streaming_tools.rs @@ -27,7 +27,7 @@ fn streaming_tool_params() -> serde_json::Value { #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn streaming_tools_smoke() { - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let agent = client .agent(gemini::completion::GEMINI_2_5_FLASH) .preamble(STREAMING_TOOLS_PREAMBLE) @@ -47,7 +47,7 @@ async fn streaming_tools_smoke() { #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn raw_stream_emits_required_zero_arg_tool_call() { - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let model = client.completion_model(gemini::completion::GEMINI_2_5_FLASH); let request = model .completion_request(REQUIRED_ZERO_ARG_TOOL_PROMPT) @@ -63,7 +63,7 @@ async fn raw_stream_emits_required_zero_arg_tool_call() { #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn streaming_tools_surface_two_distinct_tool_calls_before_final_answer() { - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let agent = client .agent(gemini::completion::GEMINI_2_5_FLASH) .preamble(TWO_TOOL_STREAM_PREAMBLE) @@ -88,7 +88,7 @@ async fn streaming_tools_surface_two_distinct_tool_calls_before_final_answer() { #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn streaming_tools_emit_tool_call_before_later_text() { - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let agent = client .agent(gemini::completion::GEMINI_2_5_FLASH) .preamble(ORDERED_TOOL_STREAM_PREAMBLE) @@ -113,6 +113,7 @@ async fn streaming_tools_emit_tool_call_before_later_text() { #[ignore = "requires GEMINI_API_KEY"] async fn example_streaming_with_tools() { let agent = gemini::Client::from_env() + .expect("client should build") .agent(gemini::completion::GEMINI_2_5_FLASH) .preamble( "You are a calculator here to help the user 
perform arithmetic operations. \ diff --git a/rig/rig-core/tests/gemini/structured_output.rs b/rig/rig-core/tests/gemini/structured_output.rs index fa803f13f..1dcc96e3a 100644 --- a/rig/rig-core/tests/gemini/structured_output.rs +++ b/rig/rig-core/tests/gemini/structured_output.rs @@ -11,7 +11,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn structured_output_smoke() { - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let agent = client .agent("gemini-3-flash-preview") .output_schema::() diff --git a/rig/rig-core/tests/gemini/transcription.rs b/rig/rig-core/tests/gemini/transcription.rs index 7ef14d065..f388057ef 100644 --- a/rig/rig-core/tests/gemini/transcription.rs +++ b/rig/rig-core/tests/gemini/transcription.rs @@ -10,7 +10,7 @@ use crate::support::{AUDIO_FIXTURE_PATH, assert_nonempty_response}; #[tokio::test] #[ignore = "requires GEMINI_API_KEY"] async fn transcription_smoke() { - let client = gemini::Client::from_env(); + let client = gemini::Client::from_env().expect("client should build"); let model = client.transcription_model(gemini::completion::GEMINI_3_FLASH_PREVIEW); let response = model .transcription_request() diff --git a/rig/rig-core/tests/groq.rs b/rig/rig-core/tests/groq.rs index 61bb1dd79..16a89f540 100644 --- a/rig/rig-core/tests/groq.rs +++ b/rig/rig-core/tests/groq.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Groq integration tests. //! //! Run the provider target with: diff --git a/rig/rig-core/tests/groq/agent.rs b/rig/rig-core/tests/groq/agent.rs index c046eb5a8..571373e3a 100644 --- a/rig/rig-core/tests/groq/agent.rs +++ b/rig/rig-core/tests/groq/agent.rs @@ -11,7 +11,7 @@ use super::AGENT_MODEL; #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn completion_smoke() { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let agent = client.agent(AGENT_MODEL).preamble(BASIC_PREAMBLE).build(); let response = agent diff --git a/rig/rig-core/tests/groq/context.rs b/rig/rig-core/tests/groq/context.rs index bae7b6451..f825751ac 100644 --- a/rig/rig-core/tests/groq/context.rs +++ b/rig/rig-core/tests/groq/context.rs @@ -11,7 +11,7 @@ use super::CONTEXT_MODEL; #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn context_smoke() { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let agent = CONTEXT_DOCS .iter() .copied() diff --git a/rig/rig-core/tests/groq/extractor.rs b/rig/rig-core/tests/groq/extractor.rs index c3fcab162..0f542e419 100644 --- a/rig/rig-core/tests/groq/extractor.rs +++ b/rig/rig-core/tests/groq/extractor.rs @@ -10,7 +10,7 @@ use super::EXTRACTOR_MODEL; #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn extractor_smoke() { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let extractor = client.extractor::(EXTRACTOR_MODEL).build(); let response = extractor diff --git a/rig/rig-core/tests/groq/extractor_usage.rs b/rig/rig-core/tests/groq/extractor_usage.rs index c4e17298e..c341647a7 100644 --- a/rig/rig-core/tests/groq/extractor_usage.rs +++ b/rig/rig-core/tests/groq/extractor_usage.rs @@ -45,7 +45,7 @@ fn assert_compatible_professions(left: Option<&str>, right: &str) { #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn 
extract_backward_compatibility() -> Result<()> { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let extractor = client .extractor::(EXTRACTOR_USAGE_BACKWARD_MODEL) .build(); @@ -64,7 +64,7 @@ async fn extract_backward_compatibility() -> Result<()> { #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn extract_with_usage_returns_data_and_usage() -> Result<()> { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let extractor = client .extractor::(EXTRACTOR_USAGE_WITH_USAGE_MODEL) .build(); @@ -86,7 +86,7 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn extract_with_chat_history_with_usage_works() -> Result<()> { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let extractor = client .extractor::
(EXTRACTOR_USAGE_CHAT_HISTORY_MODEL) .build(); @@ -115,7 +115,7 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let extractor = client .extractor::(EXTRACTOR_USAGE_SAME_DATA_MODEL) .build(); @@ -138,7 +138,7 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn usage_tracking_works_for_different_schemas() -> Result<()> { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let person_extractor = client .extractor::(EXTRACTOR_USAGE_TRACKING_MODEL) diff --git a/rig/rig-core/tests/groq/loaders.rs b/rig/rig-core/tests/groq/loaders.rs index 13d8caf1a..8312d3764 100644 --- a/rig/rig-core/tests/groq/loaders.rs +++ b/rig/rig-core/tests/groq/loaders.rs @@ -12,7 +12,7 @@ use super::LOADERS_MODEL; #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn loaders_smoke() { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let examples = FileLoader::with_glob(LOADERS_GLOB) .expect("examples glob should parse") .read_with_path() diff --git a/rig/rig-core/tests/groq/multi_extract.rs b/rig/rig-core/tests/groq/multi_extract.rs index 5d1c3784b..e461e236f 100644 --- a/rig/rig-core/tests/groq/multi_extract.rs +++ b/rig/rig-core/tests/groq/multi_extract.rs @@ -31,7 +31,7 @@ struct Sentiment { #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn batch_multi_extract_chain() -> Result<()> { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let names_extractor = client .extractor::(MULTI_EXTRACT_NAMES_MODEL) .preamble("Extract names from the given text.") diff --git a/rig/rig-core/tests/groq/permission_control.rs b/rig/rig-core/tests/groq/permission_control.rs index 9d24c4773..77968699c 100644 --- a/rig/rig-core/tests/groq/permission_control.rs +++ b/rig/rig-core/tests/groq/permission_control.rs @@ -158,6 +158,7 @@ async fn permission_control_prompt_example() -> Result<()> { let _cleanup = FileCleanup::new()?; let agent = groq::Client::from_env() + .expect("client should build") .agent(PERMISSION_CONTROL_PROMPT_MODEL) .preamble("You are a helpful assistant that can read files using different methods.") .tool(ReadFileHead) @@ -196,6 +197,7 @@ async fn permission_control_streaming_example() -> Result<()> { let _cleanup = FileCleanup::new()?; let agent = groq::Client::from_env() + .expect("client should build") .agent(PERMISSION_CONTROL_STREAMING_MODEL) .preamble("You are a helpful assistant that can read files using different methods.") .tool(ReadFileHead) diff --git a/rig/rig-core/tests/groq/request_hook.rs b/rig/rig-core/tests/groq/request_hook.rs index 9abd81fce..c44a9f028 100644 --- a/rig/rig-core/tests/groq/request_hook.rs +++ b/rig/rig-core/tests/groq/request_hook.rs @@ -71,6 +71,7 @@ where #[ignore = "requires GROQ_API_KEY"] async fn request_hook_records_prompt_and_response() -> Result<()> { let agent = groq::Client::from_env() + .expect("client should build") .agent(REQUEST_HOOK_MODEL) .preamble("You are a comedian here to entertain the user using humour and jokes.") .build(); diff --git a/rig/rig-core/tests/groq/streaming.rs b/rig/rig-core/tests/groq/streaming.rs index 
abbc23556..70cdc6ca7 100644 --- a/rig/rig-core/tests/groq/streaming.rs +++ b/rig/rig-core/tests/groq/streaming.rs @@ -13,7 +13,7 @@ use super::STREAMING_MODEL; #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn streaming_smoke() { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let agent = client .agent(STREAMING_MODEL) .preamble(STREAMING_PREAMBLE) diff --git a/rig/rig-core/tests/groq/streaming_reasoning.rs b/rig/rig-core/tests/groq/streaming_reasoning.rs index 47d662385..d155d9130 100644 --- a/rig/rig-core/tests/groq/streaming_reasoning.rs +++ b/rig/rig-core/tests/groq/streaming_reasoning.rs @@ -11,7 +11,7 @@ use super::STREAMING_REASONING_MODEL; #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn parsed_reasoning_stream() { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let agent = client .agent(STREAMING_REASONING_MODEL) .preamble("You are a comedian here to entertain the user using humour and jokes.") diff --git a/rig/rig-core/tests/groq/streaming_tools.rs b/rig/rig-core/tests/groq/streaming_tools.rs index 0ada6fe9e..5e5599e85 100644 --- a/rig/rig-core/tests/groq/streaming_tools.rs +++ b/rig/rig-core/tests/groq/streaming_tools.rs @@ -25,7 +25,7 @@ use super::{ #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn raw_stream_emits_required_zero_arg_tool_call() { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let model = client.completion_model(STREAMING_TOOLS_RAW_MODEL); let request = model .completion_request(REQUIRED_ZERO_ARG_TOOL_PROMPT) @@ -40,7 +40,7 @@ async fn raw_stream_emits_required_zero_arg_tool_call() { #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn raw_stream_surfaces_two_distinct_tool_calls_before_text() { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let model = client.completion_model(STREAMING_TOOLS_RAW_MODEL); let request = model .completion_request(TWO_TOOL_STREAM_PROMPT) @@ -66,7 +66,7 @@ async fn raw_stream_surfaces_two_distinct_tool_calls_before_text() { #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn streaming_tools_surface_two_distinct_tool_calls_before_final_answer() { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let agent = client .agent(STREAMING_TOOLS_MULTI_MODEL) .preamble(TWO_TOOL_STREAM_PREAMBLE) @@ -90,7 +90,7 @@ async fn streaming_tools_surface_two_distinct_tool_calls_before_final_answer() { #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn streaming_tools_emit_tool_call_before_later_text() { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let agent = client .agent(STREAMING_TOOLS_ORDERED_MODEL) .preamble(ORDERED_TOOL_STREAM_PREAMBLE) @@ -113,7 +113,7 @@ async fn streaming_tools_emit_tool_call_before_later_text() { #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn raw_followup_uses_tool_result_without_new_tool_calls() { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let model = client.completion_model(STREAMING_TOOLS_RAW_MODEL); let request = model .completion_request(ORDERED_TOOL_STREAM_PROMPT) diff --git a/rig/rig-core/tests/groq/tools.rs b/rig/rig-core/tests/groq/tools.rs index 5b4a3b93b..6d561fc66 100644 --- 
a/rig/rig-core/tests/groq/tools.rs +++ b/rig/rig-core/tests/groq/tools.rs @@ -11,7 +11,7 @@ use super::TOOLS_MODEL; #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn tools_smoke() { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let agent = client .agent(TOOLS_MODEL) .preamble( diff --git a/rig/rig-core/tests/groq/transcription.rs b/rig/rig-core/tests/groq/transcription.rs index 933c15b14..0c7f32e9c 100644 --- a/rig/rig-core/tests/groq/transcription.rs +++ b/rig/rig-core/tests/groq/transcription.rs @@ -10,7 +10,7 @@ use crate::support::{AUDIO_FIXTURE_PATH, assert_nonempty_response}; #[tokio::test] #[ignore = "requires GROQ_API_KEY"] async fn transcription_smoke() { - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let model = client.transcription_model(groq::WHISPER_LARGE_V3); let response = model .transcription_request() diff --git a/rig/rig-core/tests/groq/typed_prompt_tools.rs b/rig/rig-core/tests/groq/typed_prompt_tools.rs index 520b8ea30..7988ac36a 100644 --- a/rig/rig-core/tests/groq/typed_prompt_tools.rs +++ b/rig/rig-core/tests/groq/typed_prompt_tools.rs @@ -77,7 +77,7 @@ impl Tool for WeatherTool { #[ignore = "requires GROQ_API_KEY"] async fn prompt_typed_with_tool_call_roundtrip() -> Result<()> { let call_count = Arc::new(AtomicUsize::new(0)); - let client = groq::Client::from_env(); + let client = groq::Client::from_env().expect("client should build"); let agent = client .agent(TYPED_PROMPT_TOOLS_MODEL) .preamble( diff --git a/rig/rig-core/tests/huggingface.rs b/rig/rig-core/tests/huggingface.rs index d3e78bd5c..34e51fd63 100644 --- a/rig/rig-core/tests/huggingface.rs +++ b/rig/rig-core/tests/huggingface.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Hugging Face integration tests. //! //! 
Run the full provider target with: diff --git a/rig/rig-core/tests/huggingface/agent.rs b/rig/rig-core/tests/huggingface/agent.rs index bf351fb2b..c4ec8fdb8 100644 --- a/rig/rig-core/tests/huggingface/agent.rs +++ b/rig/rig-core/tests/huggingface/agent.rs @@ -9,7 +9,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[tokio::test] #[ignore = "requires HUGGINGFACE_API_KEY"] async fn completion_smoke() { - let client = huggingface::Client::from_env(); + let client = huggingface::Client::from_env().expect("client should build"); let agent = client .agent("deepseek-ai/DeepSeek-R1-Distill-Qwen-32B") .preamble(BASIC_PREAMBLE) diff --git a/rig/rig-core/tests/huggingface/context.rs b/rig/rig-core/tests/huggingface/context.rs index 8a6890db5..eaa3a4dde 100644 --- a/rig/rig-core/tests/huggingface/context.rs +++ b/rig/rig-core/tests/huggingface/context.rs @@ -9,7 +9,7 @@ use crate::support::{CONTEXT_DOCS, CONTEXT_PROMPT, assert_contains_any_case_inse #[tokio::test] #[ignore = "requires HUGGINGFACE_API_KEY"] async fn context_smoke() { - let client = huggingface::Client::from_env(); + let client = huggingface::Client::from_env().expect("client should build"); let agent = CONTEXT_DOCS .iter() .copied() diff --git a/rig/rig-core/tests/huggingface/image_generation.rs b/rig/rig-core/tests/huggingface/image_generation.rs index e521b5201..ecde405cd 100644 --- a/rig/rig-core/tests/huggingface/image_generation.rs +++ b/rig/rig-core/tests/huggingface/image_generation.rs @@ -10,7 +10,7 @@ use crate::support::{IMAGE_PROMPT, assert_nonempty_bytes}; #[tokio::test] #[ignore = "requires HUGGINGFACE_API_KEY"] async fn image_generation_smoke() { - let client = huggingface::Client::from_env(); + let client = huggingface::Client::from_env().expect("client should build"); let model = client.image_generation_model("stabilityai/stable-diffusion-3-medium-diffusers"); let response = model diff --git a/rig/rig-core/tests/huggingface/loaders.rs b/rig/rig-core/tests/huggingface/loaders.rs index 813c54b00..42dbe65cc 100644 --- a/rig/rig-core/tests/huggingface/loaders.rs +++ b/rig/rig-core/tests/huggingface/loaders.rs @@ -10,7 +10,7 @@ use crate::support::{LOADERS_GLOB, LOADERS_PROMPT, assert_loader_answer_is_relev #[tokio::test] #[ignore = "requires HUGGINGFACE_API_KEY"] async fn loaders_smoke() { - let client = huggingface::Client::from_env(); + let client = huggingface::Client::from_env().expect("client should build"); let examples = FileLoader::with_glob(LOADERS_GLOB) .expect("examples glob should parse") .read_with_path() diff --git a/rig/rig-core/tests/huggingface/streaming.rs b/rig/rig-core/tests/huggingface/streaming.rs index cd1bbabcc..83ce4565e 100644 --- a/rig/rig-core/tests/huggingface/streaming.rs +++ b/rig/rig-core/tests/huggingface/streaming.rs @@ -11,7 +11,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires HUGGINGFACE_API_KEY"] async fn streaming_smoke() { - let client = huggingface::Client::from_env(); + let client = huggingface::Client::from_env().expect("client should build"); let agent = client .agent("meta-llama/Meta-Llama-3.1-8B-Instruct") .preamble(STREAMING_PREAMBLE) diff --git a/rig/rig-core/tests/huggingface/tools.rs b/rig/rig-core/tests/huggingface/tools.rs index 10e06bb08..1e0ee8714 100644 --- a/rig/rig-core/tests/huggingface/tools.rs +++ b/rig/rig-core/tests/huggingface/tools.rs @@ -11,7 +11,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires HUGGINGFACE_API_KEY"] async fn tools_smoke() { - let client = huggingface::Client::from_env(); + let client 
= huggingface::Client::from_env().expect("client should build"); let agent = client .agent("deepseek-ai/DeepSeek-R1-Distill-Qwen-32B") .preamble(TOOLS_PREAMBLE) diff --git a/rig/rig-core/tests/huggingface/transcription.rs b/rig/rig-core/tests/huggingface/transcription.rs index a04313cc4..f4f292166 100644 --- a/rig/rig-core/tests/huggingface/transcription.rs +++ b/rig/rig-core/tests/huggingface/transcription.rs @@ -10,7 +10,7 @@ use crate::support::{AUDIO_FIXTURE_PATH, assert_nonempty_response}; #[tokio::test] #[ignore = "requires HUGGINGFACE_API_KEY"] async fn transcription_smoke() { - let client = huggingface::Client::from_env(); + let client = huggingface::Client::from_env().expect("client should build"); let model = client.transcription_model("whisper-large-v3"); let response = model .transcription_request() diff --git a/rig/rig-core/tests/hyperbolic.rs b/rig/rig-core/tests/hyperbolic.rs index 4dcb9da1c..26bb3e114 100644 --- a/rig/rig-core/tests/hyperbolic.rs +++ b/rig/rig-core/tests/hyperbolic.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Hyperbolic integration tests. //! //! Run the full provider target with: diff --git a/rig/rig-core/tests/hyperbolic/agent.rs b/rig/rig-core/tests/hyperbolic/agent.rs index 25454b3b7..2b90f7fff 100644 --- a/rig/rig-core/tests/hyperbolic/agent.rs +++ b/rig/rig-core/tests/hyperbolic/agent.rs @@ -9,7 +9,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[tokio::test] #[ignore = "requires HYPERBOLIC_API_KEY"] async fn completion_smoke() { - let client = hyperbolic::Client::from_env(); + let client = hyperbolic::Client::from_env().expect("client should build"); let agent = client .agent(hyperbolic::DEEPSEEK_R1) .preamble(BASIC_PREAMBLE) diff --git a/rig/rig-core/tests/hyperbolic/audio_generation.rs b/rig/rig-core/tests/hyperbolic/audio_generation.rs index 43fbeb5e3..b2c83b04c 100644 --- a/rig/rig-core/tests/hyperbolic/audio_generation.rs +++ b/rig/rig-core/tests/hyperbolic/audio_generation.rs @@ -10,7 +10,7 @@ use crate::support::{AUDIO_TEXT, assert_nonempty_bytes}; #[tokio::test] #[ignore = "requires HYPERBOLIC_API_KEY"] async fn audio_generation_smoke() { - let client = hyperbolic::Client::from_env(); + let client = hyperbolic::Client::from_env().expect("client should build"); let model = client.audio_generation_model("EN"); let response = model diff --git a/rig/rig-core/tests/hyperbolic/image_generation.rs b/rig/rig-core/tests/hyperbolic/image_generation.rs index 23911a4c2..c89745466 100644 --- a/rig/rig-core/tests/hyperbolic/image_generation.rs +++ b/rig/rig-core/tests/hyperbolic/image_generation.rs @@ -10,7 +10,7 @@ use crate::support::{IMAGE_PROMPT, assert_nonempty_bytes}; #[tokio::test] #[ignore = "requires HYPERBOLIC_API_KEY"] async fn image_generation_smoke() { - let client = hyperbolic::Client::from_env(); + let client = hyperbolic::Client::from_env().expect("client should build"); let model = client.image_generation_model(hyperbolic::SDXL_TURBO); let response = model diff --git a/rig/rig-core/tests/llamacpp.rs b/rig/rig-core/tests/llamacpp.rs index 3f7e33195..da91bdeb7 100644 --- a/rig/rig-core/tests/llamacpp.rs +++ b/rig/rig-core/tests/llamacpp.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! llama.cpp OpenAI-compatible integration tests. //! //! 
Run the full provider target with: diff --git a/rig/rig-core/tests/llamafile.rs b/rig/rig-core/tests/llamafile.rs index c7c604549..7a595d34d 100644 --- a/rig/rig-core/tests/llamafile.rs +++ b/rig/rig-core/tests/llamafile.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Llamafile integration tests. //! //! Run the full provider target with: diff --git a/rig/rig-core/tests/llamafile/support.rs b/rig/rig-core/tests/llamafile/support.rs index 438ec9433..cbca80676 100644 --- a/rig/rig-core/tests/llamafile/support.rs +++ b/rig/rig-core/tests/llamafile/support.rs @@ -15,7 +15,7 @@ pub(super) fn model_name() -> String { } pub(super) fn client() -> llamafile::Client { - llamafile::Client::from_url(&api_base_url()) + llamafile::Client::from_url(&api_base_url()).expect("client should build") } fn server_addr() -> Option { diff --git a/rig/rig-core/tests/minimax.rs b/rig/rig-core/tests/minimax.rs index 0202b9b7b..39349cc8b 100644 --- a/rig/rig-core/tests/minimax.rs +++ b/rig/rig-core/tests/minimax.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! MiniMax integration tests. //! //! Run the provider target with: diff --git a/rig/rig-core/tests/minimax/anthropic.rs b/rig/rig-core/tests/minimax/anthropic.rs index 3b4d95e9f..ede37fad3 100644 --- a/rig/rig-core/tests/minimax/anthropic.rs +++ b/rig/rig-core/tests/minimax/anthropic.rs @@ -10,6 +10,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[ignore = "requires MINIMAX_API_KEY"] async fn anthropic_compatible_completion_smoke() { let response = minimax::AnthropicClient::from_env() + .expect("client should build") .agent(minimax::MINIMAX_M2) .preamble(BASIC_PREAMBLE) .build() diff --git a/rig/rig-core/tests/minimax/openai.rs b/rig/rig-core/tests/minimax/openai.rs index b2fd84f38..f84bd66dc 100644 --- a/rig/rig-core/tests/minimax/openai.rs +++ b/rig/rig-core/tests/minimax/openai.rs @@ -10,6 +10,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[ignore = "requires MINIMAX_API_KEY"] async fn openai_compatible_completion_smoke() { let response = minimax::Client::from_env() + .expect("client should build") .agent(minimax::MINIMAX_M2_7) .preamble(BASIC_PREAMBLE) .build() diff --git a/rig/rig-core/tests/mira.rs b/rig/rig-core/tests/mira.rs index 8899e7b34..67cd9c3db 100644 --- a/rig/rig-core/tests/mira.rs +++ b/rig/rig-core/tests/mira.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Mira integration tests. //! //! 
Run the full provider target with: diff --git a/rig/rig-core/tests/mira/agent.rs b/rig/rig-core/tests/mira/agent.rs index 17a38f65e..723f3b66b 100644 --- a/rig/rig-core/tests/mira/agent.rs +++ b/rig/rig-core/tests/mira/agent.rs @@ -9,7 +9,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[tokio::test] #[ignore = "requires MIRA_API_KEY"] async fn completion_smoke() { - let client = mira::Client::from_env(); + let client = mira::Client::from_env().expect("client should build"); let agent = client .agent(openai::GPT_4O) .preamble(BASIC_PREAMBLE) diff --git a/rig/rig-core/tests/mira/models.rs b/rig/rig-core/tests/mira/models.rs index 7e8fbccd1..cfc3329d9 100644 --- a/rig/rig-core/tests/mira/models.rs +++ b/rig/rig-core/tests/mira/models.rs @@ -6,7 +6,7 @@ use rig::providers::mira; #[tokio::test] #[ignore = "requires MIRA_API_KEY"] async fn list_models_smoke() { - let client = mira::Client::from_env(); + let client = mira::Client::from_env().expect("client should build"); let models = client .list_models() .await diff --git a/rig/rig-core/tests/mira/tools.rs b/rig/rig-core/tests/mira/tools.rs index aa77aeaa8..328fd1c04 100644 --- a/rig/rig-core/tests/mira/tools.rs +++ b/rig/rig-core/tests/mira/tools.rs @@ -11,7 +11,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires MIRA_API_KEY"] async fn tools_smoke() { - let client = mira::Client::from_env(); + let client = mira::Client::from_env().expect("client should build"); let agent = client .agent(anthropic::completion::CLAUDE_SONNET_4_6) .preamble(TOOLS_PREAMBLE) diff --git a/rig/rig-core/tests/mistral.rs b/rig/rig-core/tests/mistral.rs index 647a3fb6b..ed114a82f 100644 --- a/rig/rig-core/tests/mistral.rs +++ b/rig/rig-core/tests/mistral.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Mistral integration tests. //! //! 
Run the full provider target with: diff --git a/rig/rig-core/tests/mistral/agent.rs b/rig/rig-core/tests/mistral/agent.rs index a12edaebb..28cfd6823 100644 --- a/rig/rig-core/tests/mistral/agent.rs +++ b/rig/rig-core/tests/mistral/agent.rs @@ -11,7 +11,7 @@ use super::DEFAULT_MODEL; #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn completion_smoke() { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let agent = client.agent(DEFAULT_MODEL).preamble(BASIC_PREAMBLE).build(); let response = agent diff --git a/rig/rig-core/tests/mistral/embeddings.rs b/rig/rig-core/tests/mistral/embeddings.rs index c001d4fe6..a4eb337f0 100644 --- a/rig/rig-core/tests/mistral/embeddings.rs +++ b/rig/rig-core/tests/mistral/embeddings.rs @@ -18,7 +18,7 @@ struct Greetings { #[tokio::test] #[ignore = "requires MISTRAL_API_KEY and --features derive"] async fn derive_embeddings_and_vector_search() { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let embedding_model = client.embedding_model(mistral::embedding::MISTRAL_EMBED); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) .document(Greetings { diff --git a/rig/rig-core/tests/mistral/extractor.rs b/rig/rig-core/tests/mistral/extractor.rs index 5e42bd801..d1774c98a 100644 --- a/rig/rig-core/tests/mistral/extractor.rs +++ b/rig/rig-core/tests/mistral/extractor.rs @@ -10,7 +10,7 @@ use super::DEFAULT_MODEL; #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn extractor_smoke() { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let extractor = client.extractor::(DEFAULT_MODEL).build(); let response = extractor diff --git a/rig/rig-core/tests/mistral/extractor_usage.rs b/rig/rig-core/tests/mistral/extractor_usage.rs index bf2d6ee4b..e083778a3 100644 --- a/rig/rig-core/tests/mistral/extractor_usage.rs +++ b/rig/rig-core/tests/mistral/extractor_usage.rs @@ -41,7 +41,7 @@ fn assert_compatible_professions(left: Option<&str>, right: &str) { #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn extract_backward_compatibility() -> Result<()> { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let extractor = client.extractor::(DEFAULT_MODEL).build(); let person = extractor @@ -58,7 +58,7 @@ async fn extract_backward_compatibility() -> Result<()> { #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn extract_with_usage_returns_data_and_usage() -> Result<()> { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let extractor = client.extractor::(DEFAULT_MODEL).build(); let response: ExtractionResponse = extractor @@ -78,7 +78,7 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn extract_with_chat_history_with_usage_works() -> Result<()> { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let extractor = client.extractor::
(DEFAULT_MODEL).build(); let chat_history = vec![Message::user( @@ -105,7 +105,7 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let extractor = client.extractor::(DEFAULT_MODEL).build(); let text = "Bob Johnson is a 55 year old retired teacher."; @@ -126,7 +126,7 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn usage_tracking_works_for_different_schemas() -> Result<()> { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let person_extractor = client.extractor::(DEFAULT_MODEL).build(); let person_response = person_extractor diff --git a/rig/rig-core/tests/mistral/models.rs b/rig/rig-core/tests/mistral/models.rs index 6e6a00bbd..b135480a9 100644 --- a/rig/rig-core/tests/mistral/models.rs +++ b/rig/rig-core/tests/mistral/models.rs @@ -6,7 +6,7 @@ use rig::providers::mistral; #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn list_models_smoke() { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let models = match client.list_models().await { Ok(models) => models, Err(error) => { diff --git a/rig/rig-core/tests/mistral/multi_extract.rs b/rig/rig-core/tests/mistral/multi_extract.rs index 9864c3a7d..726837eab 100644 --- a/rig/rig-core/tests/mistral/multi_extract.rs +++ b/rig/rig-core/tests/mistral/multi_extract.rs @@ -74,7 +74,7 @@ fn assert_sentiment_shape(extract: &CombinedExtract) { #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn batch_multi_extract_chain() -> Result<()> { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let names_extractor = client .extractor::(DEFAULT_MODEL) .preamble("Extract names from the given text.") diff --git a/rig/rig-core/tests/mistral/permission_control.rs b/rig/rig-core/tests/mistral/permission_control.rs index 992ba0d23..a62f8b36d 100644 --- a/rig/rig-core/tests/mistral/permission_control.rs +++ b/rig/rig-core/tests/mistral/permission_control.rs @@ -156,6 +156,7 @@ async fn permission_control_prompt_example() -> Result<()> { let _cleanup = FileCleanup::new()?; let agent = mistral::Client::from_env() + .expect("client should build") .agent(TOOL_MODEL) .preamble("You are a helpful assistant that can read files using different methods.") .tool(ReadFileHead) @@ -191,6 +192,7 @@ async fn permission_control_streaming_example() -> Result<()> { let _cleanup = FileCleanup::new()?; let agent = mistral::Client::from_env() + .expect("client should build") .agent(TOOL_MODEL) .preamble("You are a helpful assistant that can read files using different methods.") .tool(ReadFileHead) diff --git a/rig/rig-core/tests/mistral/request_hook.rs b/rig/rig-core/tests/mistral/request_hook.rs index c865a4e4f..b829892f7 100644 --- a/rig/rig-core/tests/mistral/request_hook.rs +++ b/rig/rig-core/tests/mistral/request_hook.rs @@ -71,6 +71,7 @@ where #[ignore = "requires MISTRAL_API_KEY"] async fn request_hook_records_prompt_and_response() -> Result<()> { let agent = mistral::Client::from_env() + .expect("client should build") .agent(DEFAULT_MODEL) .preamble("You are a comedian here to entertain the user 
using humour and jokes.") .build(); diff --git a/rig/rig-core/tests/mistral/streaming.rs b/rig/rig-core/tests/mistral/streaming.rs index 38cddae1c..d6e0f0aca 100644 --- a/rig/rig-core/tests/mistral/streaming.rs +++ b/rig/rig-core/tests/mistral/streaming.rs @@ -13,7 +13,7 @@ use super::DEFAULT_MODEL; #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn streaming_smoke() { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let agent = client .agent(DEFAULT_MODEL) .preamble(STREAMING_PREAMBLE) @@ -30,7 +30,7 @@ async fn streaming_smoke() { #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn example_streaming_prompt() { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let agent = client .agent(DEFAULT_MODEL) .preamble("Be precise and concise.") diff --git a/rig/rig-core/tests/mistral/streaming_tools.rs b/rig/rig-core/tests/mistral/streaming_tools.rs index ea375cd46..9dd21e8f3 100644 --- a/rig/rig-core/tests/mistral/streaming_tools.rs +++ b/rig/rig-core/tests/mistral/streaming_tools.rs @@ -17,7 +17,7 @@ use super::TOOL_MODEL; #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn streaming_tools_smoke() { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let agent = client .agent(TOOL_MODEL) .preamble(STREAMING_TOOLS_PREAMBLE) @@ -37,7 +37,7 @@ async fn streaming_tools_smoke() { #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn example_streaming_with_tools() { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let agent = client .agent(TOOL_MODEL) .preamble( @@ -60,7 +60,7 @@ async fn example_streaming_with_tools() { #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn stream_prompt_tool_roundtrip_preserves_streaming_contract() { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let agent = client .agent(TOOL_MODEL) .preamble(ORDERED_TOOL_STREAM_PREAMBLE) @@ -84,7 +84,7 @@ async fn stream_prompt_tool_roundtrip_preserves_streaming_contract() { #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn stream_chat_tool_roundtrip_preserves_streaming_contract() { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let agent = client .agent(TOOL_MODEL) .preamble(ORDERED_TOOL_STREAM_PREAMBLE) diff --git a/rig/rig-core/tests/mistral/transcription.rs b/rig/rig-core/tests/mistral/transcription.rs index 7599fd2f4..4886917f4 100644 --- a/rig/rig-core/tests/mistral/transcription.rs +++ b/rig/rig-core/tests/mistral/transcription.rs @@ -10,7 +10,7 @@ use crate::support::{AUDIO_FIXTURE_PATH, assert_nonempty_response}; #[tokio::test] #[ignore = "requires MISTRAL_API_KEY"] async fn transcription_smoke() { - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let model = client.transcription_model(mistral::VOXTRAL_MINI); let response = model .transcription_request() diff --git a/rig/rig-core/tests/mistral/typed_prompt_tools.rs b/rig/rig-core/tests/mistral/typed_prompt_tools.rs index 2b4b4da1a..50f2ed9ce 100644 --- a/rig/rig-core/tests/mistral/typed_prompt_tools.rs +++ b/rig/rig-core/tests/mistral/typed_prompt_tools.rs @@ -77,7 +77,7 @@ impl Tool for WeatherTool { #[ignore = "requires 
MISTRAL_API_KEY"] async fn prompt_typed_with_tool_call_roundtrip() -> Result<()> { let call_count = Arc::new(AtomicUsize::new(0)); - let client = mistral::Client::from_env(); + let client = mistral::Client::from_env().expect("client should build"); let agent = client .agent(TOOL_MODEL) .preamble( diff --git a/rig/rig-core/tests/moonshot.rs b/rig/rig-core/tests/moonshot.rs index 72b2efe64..3daef3fb9 100644 --- a/rig/rig-core/tests/moonshot.rs +++ b/rig/rig-core/tests/moonshot.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Moonshot integration tests. //! //! Run the provider target with: diff --git a/rig/rig-core/tests/moonshot/agent.rs b/rig/rig-core/tests/moonshot/agent.rs index 0a31598e6..5e16b9721 100644 --- a/rig/rig-core/tests/moonshot/agent.rs +++ b/rig/rig-core/tests/moonshot/agent.rs @@ -9,7 +9,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[tokio::test] #[ignore = "requires MOONSHOT_API_KEY"] async fn completion_smoke() { - let client = moonshot::Client::from_env(); + let client = moonshot::Client::from_env().expect("moonshot client should build"); let agent = client .agent(moonshot::MOONSHOT_CHAT) .preamble(BASIC_PREAMBLE) diff --git a/rig/rig-core/tests/moonshot/anthropic.rs b/rig/rig-core/tests/moonshot/anthropic.rs index d9df4be14..403f7db26 100644 --- a/rig/rig-core/tests/moonshot/anthropic.rs +++ b/rig/rig-core/tests/moonshot/anthropic.rs @@ -10,6 +10,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[ignore = "requires MOONSHOT_API_KEY"] async fn anthropic_compatible_completion_smoke() { let response = moonshot::AnthropicClient::from_env() + .expect("moonshot anthropic client should build") .agent(moonshot::KIMI_K2_5) .preamble(BASIC_PREAMBLE) .build() diff --git a/rig/rig-core/tests/moonshot/context.rs b/rig/rig-core/tests/moonshot/context.rs index 7b46cca42..486d24d4f 100644 --- a/rig/rig-core/tests/moonshot/context.rs +++ b/rig/rig-core/tests/moonshot/context.rs @@ -9,7 +9,7 @@ use crate::support::{CONTEXT_DOCS, CONTEXT_PROMPT, assert_contains_any_case_inse #[tokio::test] #[ignore = "requires MOONSHOT_API_KEY"] async fn context_smoke() { - let client = moonshot::Client::from_env(); + let client = moonshot::Client::from_env().expect("moonshot client should build"); let agent = CONTEXT_DOCS .iter() .copied() diff --git a/rig/rig-core/tests/moonshot/reasoning_history.rs b/rig/rig-core/tests/moonshot/reasoning_history.rs index 81cae1baf..489efcff2 100644 --- a/rig/rig-core/tests/moonshot/reasoning_history.rs +++ b/rig/rig-core/tests/moonshot/reasoning_history.rs @@ -21,7 +21,9 @@ fn response_text(choice: &rig::OneOrMany) -> String { #[tokio::test] #[ignore = "requires MOONSHOT_API_KEY"] async fn assistant_reasoning_content_roundtrips_in_history() { - let model = moonshot::Client::from_env().completion_model(moonshot::KIMI_K2_5); + let model = moonshot::Client::from_env() + .expect("moonshot client should build") + .completion_model(moonshot::KIMI_K2_5); let assistant = Message::Assistant { id: None, content: OneOrMany::many(vec![ diff --git a/rig/rig-core/tests/moonshot/tools.rs b/rig/rig-core/tests/moonshot/tools.rs index 9110d56ef..ca9e8b6cc 100644 --- a/rig/rig-core/tests/moonshot/tools.rs +++ b/rig/rig-core/tests/moonshot/tools.rs @@ -13,6 +13,7 @@ use crate::support::{ #[ignore = "requires MOONSHOT_API_KEY"] async fn required_tool_choice_agent_roundtrip() { let agent = 
moonshot::Client::from_env() + .expect("moonshot client should build") .agent(moonshot::KIMI_K2_5) .preamble(TOOLS_PREAMBLE) .tool_choice(ToolChoice::Required) diff --git a/rig/rig-core/tests/ollama.rs b/rig/rig-core/tests/ollama.rs index 9c91f34cd..c31900e96 100644 --- a/rig/rig-core/tests/ollama.rs +++ b/rig/rig-core/tests/ollama.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Ollama integration tests. //! //! Run the full provider target with: diff --git a/rig/rig-core/tests/ollama/multimodal.rs b/rig/rig-core/tests/ollama/multimodal.rs index 0a889fd81..a56781d5f 100644 --- a/rig/rig-core/tests/ollama/multimodal.rs +++ b/rig/rig-core/tests/ollama/multimodal.rs @@ -15,7 +15,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires a local Ollama server with a multimodal model"] async fn multimodal_image_prompt() { - let client = ollama::Client::from_env(); + let client = ollama::Client::from_env().expect("client should build"); let agent = client .agent("llava") .preamble("Describe this image and include anything notable about it.") diff --git a/rig/rig-core/tests/ollama/pause_control.rs b/rig/rig-core/tests/ollama/pause_control.rs index 78358668b..1fdc64956 100644 --- a/rig/rig-core/tests/ollama/pause_control.rs +++ b/rig/rig-core/tests/ollama/pause_control.rs @@ -10,7 +10,9 @@ use tokio::time::{Duration, sleep}; #[tokio::test] #[ignore = "requires a local Ollama server"] async fn streaming_pause_and_resume() { - let model = ollama::Client::from_env().completion_model("gemma3:4b"); + let model = ollama::Client::from_env() + .expect("client should build") + .completion_model("gemma3:4b"); let request = model .completion_request("Explain backpropagation in neural networks.") .preamble("You are a helpful AI assistant. Provide concise explanations.".to_string()) diff --git a/rig/rig-core/tests/ollama/streaming.rs b/rig/rig-core/tests/ollama/streaming.rs index f66c7696f..3fae609f9 100644 --- a/rig/rig-core/tests/ollama/streaming.rs +++ b/rig/rig-core/tests/ollama/streaming.rs @@ -10,6 +10,7 @@ use crate::support::{assert_nonempty_response, collect_stream_final_response}; #[ignore = "requires a local Ollama server"] async fn example_streaming_prompt() { let agent = ollama::Client::from_env() + .expect("client should build") .agent("llama3.2") .preamble("Be precise and concise.") .temperature(0.5) diff --git a/rig/rig-core/tests/ollama/streaming_tools.rs b/rig/rig-core/tests/ollama/streaming_tools.rs index bf8df0f36..5a8286e66 100644 --- a/rig/rig-core/tests/ollama/streaming_tools.rs +++ b/rig/rig-core/tests/ollama/streaming_tools.rs @@ -12,6 +12,7 @@ use crate::support::{ #[ignore = "requires a local Ollama server"] async fn example_streaming_with_tools() { let agent = ollama::Client::from_env() + .expect("client should build") .agent("llama3.2") .preamble( "You are a calculator here to help the user perform arithmetic operations. 
\ diff --git a/rig/rig-core/tests/ollama/structured_output.rs b/rig/rig-core/tests/ollama/structured_output.rs index 4d96cc81f..62757f312 100644 --- a/rig/rig-core/tests/ollama/structured_output.rs +++ b/rig/rig-core/tests/ollama/structured_output.rs @@ -19,7 +19,7 @@ struct Character { #[tokio::test] #[ignore = "requires a local Ollama server"] async fn structured_output_prompt() { - let client = ollama::Client::from_env(); + let client = ollama::Client::from_env().expect("client should build"); let agent = client .agent("qwen3:4b") .preamble("You are a creative fiction writer. Create detailed characters.") diff --git a/rig/rig-core/tests/openai.rs b/rig/rig-core/tests/openai.rs index acbfea0bd..a1f19a4e0 100644 --- a/rig/rig-core/tests/openai.rs +++ b/rig/rig-core/tests/openai.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! OpenAI integration tests. //! //! Run the full provider target with: diff --git a/rig/rig-core/tests/openai/agent.rs b/rig/rig-core/tests/openai/agent.rs index 096f036a9..9d34beaae 100644 --- a/rig/rig-core/tests/openai/agent.rs +++ b/rig/rig-core/tests/openai/agent.rs @@ -9,7 +9,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn completion_smoke() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let agent = client .agent(openai::GPT_4O) .preamble(BASIC_PREAMBLE) diff --git a/rig/rig-core/tests/openai/audio_generation.rs b/rig/rig-core/tests/openai/audio_generation.rs index 6e53dce6f..f0dfcb0aa 100644 --- a/rig/rig-core/tests/openai/audio_generation.rs +++ b/rig/rig-core/tests/openai/audio_generation.rs @@ -10,7 +10,7 @@ use crate::support::{AUDIO_TEXT, assert_nonempty_bytes}; #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn audio_generation_smoke() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let model = client.audio_generation_model(openai::TTS_1); let response = model diff --git a/rig/rig-core/tests/openai/completions_api.rs b/rig/rig-core/tests/openai/completions_api.rs index 61f2b113e..48925dc7e 100644 --- a/rig/rig-core/tests/openai/completions_api.rs +++ b/rig/rig-core/tests/openai/completions_api.rs @@ -26,6 +26,7 @@ use crate::support::{ #[ignore = "requires OPENAI_API_KEY"] async fn completions_api_agent_prompt() { let agent = openai::Client::from_env() + .expect("client should build") .completion_model(openai::GPT_4O) .completions_api() .into_agent_builder() @@ -43,7 +44,9 @@ async fn completions_api_agent_prompt() { #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn completions_api_raw_response_text_matches_normalized_choice_text() { - let client = openai::Client::from_env().completions_api(); + let client = openai::Client::from_env() + .expect("client should build") + .completions_api(); let response = client .completion_model(openai::GPT_4O) .completion_request(RAW_TEXT_RESPONSE_PROMPT) @@ -68,7 +71,9 @@ async fn completions_api_raw_response_text_matches_normalized_choice_text() { #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn completions_api_streams_two_tool_calls_before_final_answer() { - let client = openai::Client::from_env().completions_api(); + let client = openai::Client::from_env() + .expect("client should build") + .completions_api(); let agent = client 
.agent(openai::GPT_4O) .preamble(TWO_TOOL_STREAM_PREAMBLE) @@ -92,7 +97,9 @@ async fn completions_api_streams_two_tool_calls_before_final_answer() { #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn completions_api_raw_stream_emits_required_zero_arg_tool_call() { - let client = openai::Client::from_env().completions_api(); + let client = openai::Client::from_env() + .expect("client should build") + .completions_api(); let model = client.completion_model(openai::GPT_4O); let request = model .completion_request(REQUIRED_ZERO_ARG_TOOL_PROMPT) @@ -107,7 +114,9 @@ async fn completions_api_raw_stream_emits_required_zero_arg_tool_call() { #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn completions_api_raw_stream_surfaces_two_distinct_tool_calls_before_text() { - let client = openai::Client::from_env().completions_api(); + let client = openai::Client::from_env() + .expect("client should build") + .completions_api(); let model = client.completion_model(openai::GPT_4O); let request = model .completion_request(TWO_TOOL_STREAM_PROMPT) @@ -133,7 +142,9 @@ async fn completions_api_raw_stream_surfaces_two_distinct_tool_calls_before_text #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn completions_api_stream_emits_tool_call_before_later_text() { - let client = openai::Client::from_env().completions_api(); + let client = openai::Client::from_env() + .expect("client should build") + .completions_api(); let agent = client .agent(openai::GPT_4O) .preamble(ORDERED_TOOL_STREAM_PREAMBLE) @@ -156,7 +167,9 @@ async fn completions_api_stream_emits_tool_call_before_later_text() { #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn completions_api_raw_followup_uses_tool_result_without_new_tool_calls() { - let client = openai::Client::from_env().completions_api(); + let client = openai::Client::from_env() + .expect("client should build") + .completions_api(); let model = client.completion_model(openai::GPT_4O); let request = model .completion_request(ORDERED_TOOL_STREAM_PROMPT) diff --git a/rig/rig-core/tests/openai/extractor.rs b/rig/rig-core/tests/openai/extractor.rs index cb60bb3ac..77f0cb2e2 100644 --- a/rig/rig-core/tests/openai/extractor.rs +++ b/rig/rig-core/tests/openai/extractor.rs @@ -8,7 +8,7 @@ use crate::support::{EXTRACTOR_TEXT, SmokePerson, assert_nonempty_response}; #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn extractor_smoke() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let extractor = client.extractor::(openai::GPT_4O).build(); let response = extractor diff --git a/rig/rig-core/tests/openai/extractor_usage.rs b/rig/rig-core/tests/openai/extractor_usage.rs index b37ecad8d..e96ab4f24 100644 --- a/rig/rig-core/tests/openai/extractor_usage.rs +++ b/rig/rig-core/tests/openai/extractor_usage.rs @@ -49,7 +49,7 @@ fn assert_compatible_professions(left: Option<&str>, right: Option<&str>) { #[tokio::test] #[ignore = "This requires an API key"] async fn extract_backward_compatibility() -> Result<()> { - let client = providers::openai::Client::from_env(); + let client = providers::openai::Client::from_env().expect("client should build"); let extractor = client .extractor::(providers::openai::GPT_4O_MINI) .build(); @@ -69,7 +69,7 @@ async fn extract_backward_compatibility() -> Result<()> { #[tokio::test] #[ignore = "This requires an API key"] async fn extract_with_usage_returns_data_and_usage() -> Result<()> { - let client = providers::openai::Client::from_env(); + let client 
= providers::openai::Client::from_env().expect("client should build"); let extractor = client .extractor::(providers::openai::GPT_4O_MINI) .build(); @@ -97,7 +97,7 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { async fn extract_with_chat_history_with_usage_works() -> Result<()> { use rig::message::Message; - let client = providers::openai::Client::from_env(); + let client = providers::openai::Client::from_env().expect("client should build"); let extractor = client .extractor::
(providers::openai::GPT_4O_MINI) .build(); @@ -131,7 +131,7 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { #[tokio::test] #[ignore = "This requires an API key"] async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { - let client = providers::openai::Client::from_env(); + let client = providers::openai::Client::from_env().expect("client should build"); let extractor = client .extractor::(providers::openai::GPT_4O_MINI) .build(); @@ -161,7 +161,7 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { #[tokio::test] #[ignore = "This requires an API key"] async fn usage_tracking_works_for_different_schemas() -> Result<()> { - let client = providers::openai::Client::from_env(); + let client = providers::openai::Client::from_env().expect("client should build"); // Test with simple schema let person_extractor = client diff --git a/rig/rig-core/tests/openai/image_generation.rs b/rig/rig-core/tests/openai/image_generation.rs index 3a6fa3f29..4a0d06e39 100644 --- a/rig/rig-core/tests/openai/image_generation.rs +++ b/rig/rig-core/tests/openai/image_generation.rs @@ -10,7 +10,7 @@ use crate::support::{IMAGE_PROMPT, assert_nonempty_bytes}; #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn image_generation_smoke() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let model = client.image_generation_model(openai::DALL_E_2); let response = model diff --git a/rig/rig-core/tests/openai/models.rs b/rig/rig-core/tests/openai/models.rs index e2a298b63..2fa9eb4a8 100644 --- a/rig/rig-core/tests/openai/models.rs +++ b/rig/rig-core/tests/openai/models.rs @@ -6,7 +6,7 @@ use rig::providers::openai; #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn list_models_smoke() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let models = match client.list_models().await { Ok(models) => models, Err(error) => { diff --git a/rig/rig-core/tests/openai/multi_extract.rs b/rig/rig-core/tests/openai/multi_extract.rs index 6e43c733c..6796159db 100644 --- a/rig/rig-core/tests/openai/multi_extract.rs +++ b/rig/rig-core/tests/openai/multi_extract.rs @@ -29,7 +29,7 @@ struct Sentiment { #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn batch_multi_extract_chain() -> Result<()> { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let names_extractor = client .extractor::(openai::GPT_4O_MINI) .preamble("Extract names from the given text.") diff --git a/rig/rig-core/tests/openai/permission_control.rs b/rig/rig-core/tests/openai/permission_control.rs index 9ecc9f4e9..8b2f1430e 100644 --- a/rig/rig-core/tests/openai/permission_control.rs +++ b/rig/rig-core/tests/openai/permission_control.rs @@ -152,6 +152,7 @@ async fn permission_control_prompt_example() -> Result<()> { let _cleanup = FileCleanup::new()?; let agent = providers::openai::Client::from_env() + .expect("client should build") .agent(providers::openai::GPT_4O_MINI) .preamble("You are a helpful assistant that can read files using different methods.") .tool(ReadFileHead) @@ -186,6 +187,7 @@ async fn permission_control_streaming_example() -> Result<()> { let _cleanup = FileCleanup::new()?; let agent = providers::openai::Client::from_env() + .expect("client should build") .agent(providers::openai::GPT_4O_MINI) .preamble("You are a helpful assistant that can read files using different 
methods.") .tool(ReadFileHead) diff --git a/rig/rig-core/tests/openai/reasoning_roundtrip.rs b/rig/rig-core/tests/openai/reasoning_roundtrip.rs index 2cc911487..8830142ff 100644 --- a/rig/rig-core/tests/openai/reasoning_roundtrip.rs +++ b/rig/rig-core/tests/openai/reasoning_roundtrip.rs @@ -11,7 +11,7 @@ use crate::reasoning::{self, ReasoningRoundtripAgent}; #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn streaming() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); reasoning::run_reasoning_roundtrip_streaming(ReasoningRoundtripAgent::new( client.completion_model("gpt-5.2"), Some(serde_json::json!({ @@ -24,7 +24,7 @@ async fn streaming() { #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn nonstreaming() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); reasoning::run_reasoning_roundtrip_nonstreaming(ReasoningRoundtripAgent::new( client.completion_model("gpt-5.2"), Some(serde_json::json!({ diff --git a/rig/rig-core/tests/openai/reasoning_tool_roundtrip.rs b/rig/rig-core/tests/openai/reasoning_tool_roundtrip.rs index 3b4f5048a..7678fddb9 100644 --- a/rig/rig-core/tests/openai/reasoning_tool_roundtrip.rs +++ b/rig/rig-core/tests/openai/reasoning_tool_roundtrip.rs @@ -17,7 +17,7 @@ use crate::reasoning::{self, WeatherTool}; #[ignore = "requires OPENAI_API_KEY"] async fn streaming() { let call_count = Arc::new(AtomicUsize::new(0)); - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let agent = client .agent("gpt-5.2") .preamble(reasoning::TOOL_SYSTEM_PROMPT) @@ -49,7 +49,7 @@ async fn streaming() { #[ignore = "requires OPENAI_API_KEY"] async fn nonstreaming() { let call_count = Arc::new(AtomicUsize::new(0)); - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let agent = client .agent("gpt-5.2") .preamble(reasoning::TOOL_SYSTEM_PROMPT) diff --git a/rig/rig-core/tests/openai/request_hook.rs b/rig/rig-core/tests/openai/request_hook.rs index 093876e67..e7ffd6c7c 100644 --- a/rig/rig-core/tests/openai/request_hook.rs +++ b/rig/rig-core/tests/openai/request_hook.rs @@ -69,6 +69,7 @@ where #[ignore = "requires OPENAI_API_KEY"] async fn request_hook_records_prompt_and_response() -> Result<()> { let agent = openai::Client::from_env() + .expect("client should build") .agent(openai::GPT_4O) .preamble("You are a comedian here to entertain the user using humour and jokes.") .build(); diff --git a/rig/rig-core/tests/openai/streaming.rs b/rig/rig-core/tests/openai/streaming.rs index 9a97b56cc..4a04df623 100644 --- a/rig/rig-core/tests/openai/streaming.rs +++ b/rig/rig-core/tests/openai/streaming.rs @@ -11,7 +11,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn streaming_smoke() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let agent = client .agent(openai::GPT_4O) .preamble(STREAMING_PREAMBLE) @@ -28,7 +28,7 @@ async fn streaming_smoke() { #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn example_streaming_prompt() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let agent = client .agent(openai::GPT_4O) .preamble("Be precise and concise.") diff --git a/rig/rig-core/tests/openai/streaming_tools.rs 
b/rig/rig-core/tests/openai/streaming_tools.rs index 1b5edd26c..c41db96df 100644 --- a/rig/rig-core/tests/openai/streaming_tools.rs +++ b/rig/rig-core/tests/openai/streaming_tools.rs @@ -19,7 +19,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn streaming_tools_smoke() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let agent = client .agent(openai::GPT_4O) .preamble(STREAMING_TOOLS_PREAMBLE) @@ -38,7 +38,7 @@ async fn streaming_tools_smoke() { #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn example_streaming_with_tools() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let agent = client .agent(openai::GPT_4O) .preamble( @@ -61,7 +61,7 @@ async fn example_streaming_with_tools() { #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn responses_stream_preserves_tool_result_flow() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let agent = client .agent(openai::GPT_4O) .preamble(ORDERED_TOOL_STREAM_PREAMBLE) @@ -84,7 +84,7 @@ async fn responses_stream_preserves_tool_result_flow() { #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn raw_responses_stream_preserves_tool_then_followup_text_ordering() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let model = client.completion_model(openai::GPT_4O); let request = model .completion_request(ORDERED_TOOL_STREAM_PROMPT) diff --git a/rig/rig-core/tests/openai/structured_output.rs b/rig/rig-core/tests/openai/structured_output.rs index e0a7035a3..399834ab9 100644 --- a/rig/rig-core/tests/openai/structured_output.rs +++ b/rig/rig-core/tests/openai/structured_output.rs @@ -47,7 +47,7 @@ fn assert_weather_forecast(forecast: &WeatherForecast, expected_city: &[&str]) { #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn structured_output_smoke() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let agent = client.agent(openai::GPT_4O).build(); let response: SmokeStructuredOutput = agent @@ -61,7 +61,7 @@ async fn structured_output_smoke() { #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn prompt_typed_and_output_schema() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let agent = client .agent(openai::GPT_4O) .preamble("You are a helpful weather assistant. 
Respond with realistic weather data.") diff --git a/rig/rig-core/tests/openai/transcription.rs b/rig/rig-core/tests/openai/transcription.rs index fdd4782c7..d564d43f1 100644 --- a/rig/rig-core/tests/openai/transcription.rs +++ b/rig/rig-core/tests/openai/transcription.rs @@ -10,7 +10,7 @@ use crate::support::{AUDIO_FIXTURE_PATH, assert_nonempty_response}; #[tokio::test] #[ignore = "requires OPENAI_API_KEY"] async fn transcription_smoke() { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let model = client.transcription_model(openai::WHISPER_1); let response = model .transcription_request() diff --git a/rig/rig-core/tests/openai/websocket.rs b/rig/rig-core/tests/openai/websocket.rs index 70f3b0654..ac6f263f2 100644 --- a/rig/rig-core/tests/openai/websocket.rs +++ b/rig/rig-core/tests/openai/websocket.rs @@ -24,7 +24,7 @@ fn extract_text(choice: &rig::OneOrMany) -> String { #[tokio::test] #[ignore = "requires OPENAI_API_KEY and --features websocket"] async fn websocket_session_roundtrip() -> Result<()> { - let client = openai::Client::from_env(); + let client = openai::Client::from_env().expect("client should build"); let model_name = openai::GPT_4O_MINI; let model = client.completion_model(model_name); let mut session = client.responses_websocket(model_name).await?; diff --git a/rig/rig-core/tests/openrouter.rs b/rig/rig-core/tests/openrouter.rs index 11119b41d..98580c14c 100644 --- a/rig/rig-core/tests/openrouter.rs +++ b/rig/rig-core/tests/openrouter.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! OpenRouter integration tests. //! //! Run the full provider target with: diff --git a/rig/rig-core/tests/openrouter/agent.rs b/rig/rig-core/tests/openrouter/agent.rs index f4cf57393..ec8913b47 100644 --- a/rig/rig-core/tests/openrouter/agent.rs +++ b/rig/rig-core/tests/openrouter/agent.rs @@ -11,7 +11,7 @@ use super::DEFAULT_MODEL; #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn completion_smoke() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let agent = client.agent(DEFAULT_MODEL).preamble(BASIC_PREAMBLE).build(); let response = agent diff --git a/rig/rig-core/tests/openrouter/extractor.rs b/rig/rig-core/tests/openrouter/extractor.rs index ac0b8c5b2..a707cdef7 100644 --- a/rig/rig-core/tests/openrouter/extractor.rs +++ b/rig/rig-core/tests/openrouter/extractor.rs @@ -10,7 +10,7 @@ use super::DEFAULT_MODEL; #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn extractor_smoke() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let extractor = client.extractor::(DEFAULT_MODEL).build(); let response = extractor diff --git a/rig/rig-core/tests/openrouter/extractor_usage.rs b/rig/rig-core/tests/openrouter/extractor_usage.rs index 3df144659..2bd3bdd1e 100644 --- a/rig/rig-core/tests/openrouter/extractor_usage.rs +++ b/rig/rig-core/tests/openrouter/extractor_usage.rs @@ -41,7 +41,7 @@ fn assert_compatible_professions(left: Option<&str>, right: &str) { #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn extract_backward_compatibility() -> Result<()> { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let extractor = client.extractor::(DEFAULT_MODEL).build(); let 
person = extractor @@ -58,7 +58,7 @@ async fn extract_backward_compatibility() -> Result<()> { #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn extract_with_usage_returns_data_and_usage() -> Result<()> { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let extractor = client.extractor::(DEFAULT_MODEL).build(); let response: ExtractionResponse = extractor @@ -78,7 +78,7 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn extract_with_chat_history_with_usage_works() -> Result<()> { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let extractor = client.extractor::
(DEFAULT_MODEL).build(); let chat_history = vec![Message::user( @@ -105,7 +105,7 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let extractor = client.extractor::(DEFAULT_MODEL).build(); let text = "Bob Johnson is a 55 year old retired teacher."; @@ -126,7 +126,7 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn usage_tracking_works_for_different_schemas() -> Result<()> { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let person_extractor = client.extractor::(DEFAULT_MODEL).build(); let person_response = person_extractor diff --git a/rig/rig-core/tests/openrouter/models.rs b/rig/rig-core/tests/openrouter/models.rs index 7153eabf7..11f600b2c 100644 --- a/rig/rig-core/tests/openrouter/models.rs +++ b/rig/rig-core/tests/openrouter/models.rs @@ -9,7 +9,7 @@ use rig::providers::openrouter; #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn list_models_smoke() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let models = match client.list_models().await { Ok(models) => models, Err(error) => { diff --git a/rig/rig-core/tests/openrouter/multi_extract.rs b/rig/rig-core/tests/openrouter/multi_extract.rs index cd2751fcf..671247b8e 100644 --- a/rig/rig-core/tests/openrouter/multi_extract.rs +++ b/rig/rig-core/tests/openrouter/multi_extract.rs @@ -31,7 +31,7 @@ struct Sentiment { #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn batch_multi_extract_chain() -> Result<()> { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let names_extractor = client .extractor::(DEFAULT_MODEL) .preamble("Extract names from the given text.") diff --git a/rig/rig-core/tests/openrouter/multimodal.rs b/rig/rig-core/tests/openrouter/multimodal.rs index 118a94cfa..ca5cf5144 100644 --- a/rig/rig-core/tests/openrouter/multimodal.rs +++ b/rig/rig-core/tests/openrouter/multimodal.rs @@ -35,7 +35,7 @@ fn pdf_document() -> Document { #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn image_analysis_prompt() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let agent = client .agent(VISION_MODEL) .preamble("You are a helpful assistant that describes images in detail.") @@ -58,7 +58,7 @@ async fn image_analysis_prompt() { #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn pdf_analysis_prompt() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let agent = client .agent(VISION_MODEL) .preamble("You are a helpful assistant that summarizes documents.") @@ -81,7 +81,7 @@ async fn pdf_analysis_prompt() { #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn mixed_multimodal_prompt() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let agent = client .agent(VISION_MODEL) .preamble("You are a helpful assistant.") diff --git 
a/rig/rig-core/tests/openrouter/permission_control.rs b/rig/rig-core/tests/openrouter/permission_control.rs index 27717eb20..d47ad91ca 100644 --- a/rig/rig-core/tests/openrouter/permission_control.rs +++ b/rig/rig-core/tests/openrouter/permission_control.rs @@ -156,6 +156,7 @@ async fn permission_control_prompt_example() -> Result<()> { let _cleanup = FileCleanup::new()?; let agent = openrouter::Client::from_env() + .expect("client should build") .agent(TOOL_MODEL) .preamble("You are a helpful assistant that can read files using different methods.") .tool(ReadFileHead) @@ -191,6 +192,7 @@ async fn permission_control_streaming_example() -> Result<()> { let _cleanup = FileCleanup::new()?; let agent = openrouter::Client::from_env() + .expect("client should build") .agent(TOOL_MODEL) .preamble("You are a helpful assistant that can read files using different methods.") .tool(ReadFileHead) diff --git a/rig/rig-core/tests/openrouter/provider_selection.rs b/rig/rig-core/tests/openrouter/provider_selection.rs index 2270418c9..68c0af111 100644 --- a/rig/rig-core/tests/openrouter/provider_selection.rs +++ b/rig/rig-core/tests/openrouter/provider_selection.rs @@ -11,7 +11,7 @@ const DEEPSEEK_V3_2: &str = "deepseek/deepseek-v3.2"; #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn provider_selection_scenarios() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let scenarios = [ ( "hello", diff --git a/rig/rig-core/tests/openrouter/reasoning_roundtrip.rs b/rig/rig-core/tests/openrouter/reasoning_roundtrip.rs index dedfaea64..60ce51fc1 100644 --- a/rig/rig-core/tests/openrouter/reasoning_roundtrip.rs +++ b/rig/rig-core/tests/openrouter/reasoning_roundtrip.rs @@ -11,7 +11,7 @@ use crate::reasoning::{self, ReasoningRoundtripAgent}; #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn streaming() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); reasoning::run_reasoning_roundtrip_streaming(ReasoningRoundtripAgent::new( client.completion_model("openai/gpt-5.2"), Some(serde_json::json!({ @@ -25,7 +25,7 @@ async fn streaming() { #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn nonstreaming() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); reasoning::run_reasoning_roundtrip_nonstreaming(ReasoningRoundtripAgent::new( client.completion_model("openai/gpt-5.2"), Some(serde_json::json!({ diff --git a/rig/rig-core/tests/openrouter/reasoning_tool_roundtrip.rs b/rig/rig-core/tests/openrouter/reasoning_tool_roundtrip.rs index bc69cc927..9858b71fd 100644 --- a/rig/rig-core/tests/openrouter/reasoning_tool_roundtrip.rs +++ b/rig/rig-core/tests/openrouter/reasoning_tool_roundtrip.rs @@ -17,7 +17,7 @@ use crate::reasoning::{self, WeatherTool}; #[ignore = "requires OPENROUTER_API_KEY"] async fn streaming() { let call_count = Arc::new(AtomicUsize::new(0)); - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let agent = client .agent("openai/gpt-5.2") .preamble(reasoning::TOOL_SYSTEM_PROMPT) @@ -42,7 +42,7 @@ async fn streaming() { #[ignore = "requires OPENROUTER_API_KEY"] async fn nonstreaming() { let call_count = Arc::new(AtomicUsize::new(0)); - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let agent = client 
.agent("openai/gpt-5.2") .preamble(reasoning::TOOL_SYSTEM_PROMPT) diff --git a/rig/rig-core/tests/openrouter/request_hook.rs b/rig/rig-core/tests/openrouter/request_hook.rs index 320b5d2e0..d33df35ee 100644 --- a/rig/rig-core/tests/openrouter/request_hook.rs +++ b/rig/rig-core/tests/openrouter/request_hook.rs @@ -71,6 +71,7 @@ where #[ignore = "requires OPENROUTER_API_KEY"] async fn request_hook_records_prompt_and_response() -> Result<()> { let agent = openrouter::Client::from_env() + .expect("client should build") .agent(DEFAULT_MODEL) .preamble("You are a comedian here to entertain the user using humour and jokes.") .build(); diff --git a/rig/rig-core/tests/openrouter/streaming.rs b/rig/rig-core/tests/openrouter/streaming.rs index 9cfd17a1f..235461fe1 100644 --- a/rig/rig-core/tests/openrouter/streaming.rs +++ b/rig/rig-core/tests/openrouter/streaming.rs @@ -13,7 +13,7 @@ use super::DEFAULT_MODEL; #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn streaming_smoke() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let agent = client .agent(DEFAULT_MODEL) .preamble(STREAMING_PREAMBLE) @@ -30,7 +30,7 @@ async fn streaming_smoke() { #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn example_streaming_prompt() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let agent = client .agent(DEFAULT_MODEL) .preamble("Be precise and concise.") diff --git a/rig/rig-core/tests/openrouter/streaming_tools.rs b/rig/rig-core/tests/openrouter/streaming_tools.rs index f611a3543..081c68a43 100644 --- a/rig/rig-core/tests/openrouter/streaming_tools.rs +++ b/rig/rig-core/tests/openrouter/streaming_tools.rs @@ -25,7 +25,7 @@ use super::TOOL_MODEL; #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn streaming_tools_smoke() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let agent = client .agent(openrouter::GEMINI_FLASH_2_0) .preamble(STREAMING_TOOLS_PREAMBLE) @@ -44,7 +44,7 @@ async fn streaming_tools_smoke() { #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn raw_stream_decorates_reasoning_tool_call_metadata() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let model = client.completion_model("openai/o4-mini"); let tool_definition = WeatherTool::new(Arc::new(AtomicUsize::new(0))) .definition(String::new()) @@ -92,7 +92,7 @@ async fn raw_stream_decorates_reasoning_tool_call_metadata() { #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn raw_stream_surfaces_two_distinct_tool_calls_before_text() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let model = client.completion_model(TOOL_MODEL); let request = model .completion_request(TWO_TOOL_STREAM_PROMPT) @@ -118,7 +118,7 @@ async fn raw_stream_surfaces_two_distinct_tool_calls_before_text() { #[tokio::test] #[ignore = "requires OPENROUTER_API_KEY"] async fn raw_followup_uses_tool_result_without_new_tool_calls() { - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let model = client.completion_model(TOOL_MODEL); let request = model .completion_request(ORDERED_TOOL_STREAM_PROMPT) diff --git a/rig/rig-core/tests/openrouter/typed_prompt_tools.rs 
b/rig/rig-core/tests/openrouter/typed_prompt_tools.rs index 122312965..2a57d71e9 100644 --- a/rig/rig-core/tests/openrouter/typed_prompt_tools.rs +++ b/rig/rig-core/tests/openrouter/typed_prompt_tools.rs @@ -77,7 +77,7 @@ impl Tool for WeatherTool { #[ignore = "requires OPENROUTER_API_KEY"] async fn prompt_typed_with_tool_call_roundtrip() -> Result<()> { let call_count = Arc::new(AtomicUsize::new(0)); - let client = openrouter::Client::from_env(); + let client = openrouter::Client::from_env().expect("client should build"); let agent = client .agent(TOOL_MODEL) .preamble( diff --git a/rig/rig-core/tests/perplexity.rs b/rig/rig-core/tests/perplexity.rs index 9b6028b96..b5752a112 100644 --- a/rig/rig-core/tests/perplexity.rs +++ b/rig/rig-core/tests/perplexity.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Perplexity integration tests. //! //! Run the full provider target with: diff --git a/rig/rig-core/tests/perplexity/agent.rs b/rig/rig-core/tests/perplexity/agent.rs index ba2d4d3d4..034009b58 100644 --- a/rig/rig-core/tests/perplexity/agent.rs +++ b/rig/rig-core/tests/perplexity/agent.rs @@ -9,7 +9,7 @@ use crate::support::assert_nonempty_response; #[tokio::test] #[ignore = "requires PERPLEXITY_API_KEY"] async fn completion_smoke() { - let client = perplexity::Client::from_env(); + let client = perplexity::Client::from_env().expect("client should build"); let agent = client .agent(SONAR) .preamble("Be precise and concise.") diff --git a/rig/rig-core/tests/together.rs b/rig/rig-core/tests/together.rs index a94c3c51d..dfac1457f 100644 --- a/rig/rig-core/tests/together.rs +++ b/rig/rig-core/tests/together.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Together integration tests. //! //! 
Run the full provider target with: diff --git a/rig/rig-core/tests/together/agent.rs b/rig/rig-core/tests/together/agent.rs index b28bd2257..5b6e3c75d 100644 --- a/rig/rig-core/tests/together/agent.rs +++ b/rig/rig-core/tests/together/agent.rs @@ -9,7 +9,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[tokio::test] #[ignore = "requires TOGETHER_API_KEY"] async fn completion_smoke() { - let client = together::Client::from_env(); + let client = together::Client::from_env().expect("client should build"); let agent = client .agent(together::MIXTRAL_8X7B_INSTRUCT_V0_1) .preamble(BASIC_PREAMBLE) diff --git a/rig/rig-core/tests/together/context.rs b/rig/rig-core/tests/together/context.rs index 4729345c0..3ebb966cf 100644 --- a/rig/rig-core/tests/together/context.rs +++ b/rig/rig-core/tests/together/context.rs @@ -9,7 +9,7 @@ use crate::support::{CONTEXT_DOCS, CONTEXT_PROMPT, assert_contains_any_case_inse #[tokio::test] #[ignore = "requires TOGETHER_API_KEY"] async fn context_smoke() { - let client = together::Client::from_env(); + let client = together::Client::from_env().expect("client should build"); let agent = CONTEXT_DOCS .iter() .copied() diff --git a/rig/rig-core/tests/together/embeddings.rs b/rig/rig-core/tests/together/embeddings.rs index 15d17dc8e..f9656c8df 100644 --- a/rig/rig-core/tests/together/embeddings.rs +++ b/rig/rig-core/tests/together/embeddings.rs @@ -9,7 +9,7 @@ use crate::support::{EMBEDDING_INPUTS, assert_embeddings_nonempty_and_consistent #[tokio::test] #[ignore = "requires TOGETHER_API_KEY"] async fn embeddings_smoke() { - let client = together::Client::from_env(); + let client = together::Client::from_env().expect("client should build"); let model = client.embedding_model(together::embedding::M2_BERT_80M_8K_RETRIEVAL); let embeddings = model diff --git a/rig/rig-core/tests/together/streaming.rs b/rig/rig-core/tests/together/streaming.rs index c557dd065..ccf032bc6 100644 --- a/rig/rig-core/tests/together/streaming.rs +++ b/rig/rig-core/tests/together/streaming.rs @@ -11,7 +11,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires TOGETHER_API_KEY"] async fn streaming_smoke() { - let client = together::Client::from_env(); + let client = together::Client::from_env().expect("client should build"); let agent = client .agent(together::LLAMA_3_8B_CHAT_HF) .preamble(STREAMING_PREAMBLE) diff --git a/rig/rig-core/tests/together/streaming_tools.rs b/rig/rig-core/tests/together/streaming_tools.rs index 1751cfce2..7e7cd5a8c 100644 --- a/rig/rig-core/tests/together/streaming_tools.rs +++ b/rig/rig-core/tests/together/streaming_tools.rs @@ -12,7 +12,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires TOGETHER_API_KEY"] async fn streaming_tools_smoke() { - let client = together::Client::from_env(); + let client = together::Client::from_env().expect("client should build"); let agent = client .agent(together::LLAMA_2_70B_CHAT_TOGETHER) .preamble(STREAMING_TOOLS_PREAMBLE) diff --git a/rig/rig-core/tests/together/tools.rs b/rig/rig-core/tests/together/tools.rs index 34449ede3..8cd91762f 100644 --- a/rig/rig-core/tests/together/tools.rs +++ b/rig/rig-core/tests/together/tools.rs @@ -11,7 +11,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires TOGETHER_API_KEY"] async fn tools_smoke() { - let client = together::Client::from_env(); + let client = together::Client::from_env().expect("client should build"); let agent = client .agent(together::MIXTRAL_8X7B_INSTRUCT_V0_1) .preamble(TOOLS_PREAMBLE) diff --git 
a/rig/rig-core/tests/voyageai.rs b/rig/rig-core/tests/voyageai.rs index 376ef7f32..41bd5b0ab 100644 --- a/rig/rig-core/tests/voyageai.rs +++ b/rig/rig-core/tests/voyageai.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! VoyageAI integration tests. //! //! Run the provider target with: diff --git a/rig/rig-core/tests/voyageai/embeddings.rs b/rig/rig-core/tests/voyageai/embeddings.rs index 9808c0a00..7cc3c2430 100644 --- a/rig/rig-core/tests/voyageai/embeddings.rs +++ b/rig/rig-core/tests/voyageai/embeddings.rs @@ -9,7 +9,7 @@ use crate::support::{EMBEDDING_INPUTS, assert_embeddings_nonempty_and_consistent #[tokio::test] #[ignore = "requires VOYAGE_API_KEY"] async fn embeddings_smoke() { - let client = voyageai::Client::from_env(); + let client = voyageai::Client::from_env().expect("client should build"); let model = client.embedding_model(voyageai::VOYAGE_3_LARGE); let embeddings = model diff --git a/rig/rig-core/tests/xai.rs b/rig/rig-core/tests/xai.rs index db1fb50c3..4b974266b 100644 --- a/rig/rig-core/tests/xai.rs +++ b/rig/rig-core/tests/xai.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! xAI integration tests. //! //! Run the full provider target with: diff --git a/rig/rig-core/tests/xai/agent.rs b/rig/rig-core/tests/xai/agent.rs index e65ffc328..2e48b8acb 100644 --- a/rig/rig-core/tests/xai/agent.rs +++ b/rig/rig-core/tests/xai/agent.rs @@ -9,7 +9,7 @@ use crate::support::{BASIC_PREAMBLE, BASIC_PROMPT, assert_nonempty_response}; #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn completion_smoke() { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let agent = client .agent(xai::completion::GROK_3_MINI) .preamble(BASIC_PREAMBLE) diff --git a/rig/rig-core/tests/xai/audio_generation.rs b/rig/rig-core/tests/xai/audio_generation.rs index 4ecb2b503..a454c5d3e 100644 --- a/rig/rig-core/tests/xai/audio_generation.rs +++ b/rig/rig-core/tests/xai/audio_generation.rs @@ -11,7 +11,7 @@ use crate::support::{AUDIO_TEXT, assert_nonempty_bytes}; #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn audio_generation_smoke() { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let model = client.audio_generation_model(xai::TTS_1); let response = model diff --git a/rig/rig-core/tests/xai/context.rs b/rig/rig-core/tests/xai/context.rs index 9ea885e5e..239c99bd1 100644 --- a/rig/rig-core/tests/xai/context.rs +++ b/rig/rig-core/tests/xai/context.rs @@ -15,7 +15,7 @@ const XAI_CONTEXT_DOCS: [&str; 3] = [ #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn context_smoke() { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let agent = XAI_CONTEXT_DOCS .iter() .copied() diff --git a/rig/rig-core/tests/xai/extractor.rs b/rig/rig-core/tests/xai/extractor.rs index f1e94b445..6f1e8ceb2 100644 --- a/rig/rig-core/tests/xai/extractor.rs +++ b/rig/rig-core/tests/xai/extractor.rs @@ -8,7 +8,7 @@ use crate::support::{EXTRACTOR_TEXT, SmokePerson, assert_nonempty_response}; #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn extractor_smoke() { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let extractor = 
client.extractor::(xai::GROK_3_MINI).build(); let response = extractor diff --git a/rig/rig-core/tests/xai/extractor_usage.rs b/rig/rig-core/tests/xai/extractor_usage.rs index 4b25109b5..c7ef02180 100644 --- a/rig/rig-core/tests/xai/extractor_usage.rs +++ b/rig/rig-core/tests/xai/extractor_usage.rs @@ -39,7 +39,7 @@ fn assert_compatible_professions(left: Option<&str>, right: &str) { #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn extract_backward_compatibility() -> Result<()> { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let extractor = client.extractor::(xai::GROK_3_MINI).build(); let person = extractor @@ -56,7 +56,7 @@ async fn extract_backward_compatibility() -> Result<()> { #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn extract_with_usage_returns_data_and_usage() -> Result<()> { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let extractor = client.extractor::(xai::GROK_3_MINI).build(); let response: ExtractionResponse = extractor @@ -76,7 +76,7 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn extract_with_chat_history_with_usage_works() -> Result<()> { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let extractor = client.extractor::
(xai::GROK_3_MINI).build(); let chat_history = vec![Message::user( @@ -103,7 +103,7 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let extractor = client.extractor::(xai::GROK_3_MINI).build(); let text = "Bob Johnson is a 55 year old retired teacher."; @@ -124,7 +124,7 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn usage_tracking_works_for_different_schemas() -> Result<()> { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let person_extractor = client.extractor::(xai::GROK_3_MINI).build(); let person_response = person_extractor diff --git a/rig/rig-core/tests/xai/image_generation.rs b/rig/rig-core/tests/xai/image_generation.rs index 229b29630..2bc928782 100644 --- a/rig/rig-core/tests/xai/image_generation.rs +++ b/rig/rig-core/tests/xai/image_generation.rs @@ -11,7 +11,7 @@ use crate::support::{IMAGE_PROMPT, assert_nonempty_bytes}; #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn image_generation_smoke() { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let model = client.image_generation_model(xai::image_generation::GROK_IMAGINE_IMAGE_PRO); let response = model diff --git a/rig/rig-core/tests/xai/loaders.rs b/rig/rig-core/tests/xai/loaders.rs index a758f1fee..529bd4b8a 100644 --- a/rig/rig-core/tests/xai/loaders.rs +++ b/rig/rig-core/tests/xai/loaders.rs @@ -10,7 +10,7 @@ use crate::support::{LOADERS_GLOB, LOADERS_PROMPT, assert_loader_answer_is_relev #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn loaders_smoke() { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let examples = FileLoader::with_glob(LOADERS_GLOB) .expect("examples glob should parse") .read_with_path() diff --git a/rig/rig-core/tests/xai/multi_extract.rs b/rig/rig-core/tests/xai/multi_extract.rs index 2021d657d..a35e4fb53 100644 --- a/rig/rig-core/tests/xai/multi_extract.rs +++ b/rig/rig-core/tests/xai/multi_extract.rs @@ -72,7 +72,7 @@ fn assert_sentiment_shape(extract: &CombinedExtract) { #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn batch_multi_extract_chain() -> Result<()> { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let names_extractor = client .extractor::(xai::GROK_3_MINI) .preamble("Extract names from the given text.") diff --git a/rig/rig-core/tests/xai/permission_control.rs b/rig/rig-core/tests/xai/permission_control.rs index ea1e0b281..626159d59 100644 --- a/rig/rig-core/tests/xai/permission_control.rs +++ b/rig/rig-core/tests/xai/permission_control.rs @@ -154,6 +154,7 @@ async fn permission_control_prompt_example() -> Result<()> { let _cleanup = FileCleanup::new()?; let agent = xai::Client::from_env() + .expect("client should build") .agent(xai::GROK_4) .preamble("You are a helpful assistant that can read files using different methods.") .tool(ReadFileHead) @@ -189,6 +190,7 @@ async fn permission_control_streaming_example() -> Result<()> { let _cleanup = FileCleanup::new()?; let agent = xai::Client::from_env() + .expect("client should build") .agent(xai::GROK_4) .preamble("You are a 
helpful assistant that can read files using different methods.") .tool(ReadFileHead) diff --git a/rig/rig-core/tests/xai/reasoning_roundtrip.rs b/rig/rig-core/tests/xai/reasoning_roundtrip.rs index 117817561..f9cf71b2d 100644 --- a/rig/rig-core/tests/xai/reasoning_roundtrip.rs +++ b/rig/rig-core/tests/xai/reasoning_roundtrip.rs @@ -11,7 +11,7 @@ use crate::reasoning::{self, ReasoningRoundtripAgent}; #[tokio::test] #[ignore = "requires XAI_API_KEY - validate with grok-4-0725 once key is available"] async fn streaming() { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); reasoning::run_reasoning_roundtrip_streaming(ReasoningRoundtripAgent::new( client.completion_model(xai::GROK_3_MINI), None, @@ -22,7 +22,7 @@ async fn streaming() { #[tokio::test] #[ignore = "requires XAI_API_KEY - validate with grok-4-0725 once key is available"] async fn nonstreaming() { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); reasoning::run_reasoning_roundtrip_nonstreaming(ReasoningRoundtripAgent::new( client.completion_model(xai::GROK_3_MINI), None, diff --git a/rig/rig-core/tests/xai/reasoning_tool_roundtrip.rs b/rig/rig-core/tests/xai/reasoning_tool_roundtrip.rs index c27c667a6..f47606718 100644 --- a/rig/rig-core/tests/xai/reasoning_tool_roundtrip.rs +++ b/rig/rig-core/tests/xai/reasoning_tool_roundtrip.rs @@ -17,7 +17,7 @@ use crate::reasoning::{self, WeatherTool}; #[ignore = "requires XAI_API_KEY - validate with grok-4-0725 once key is available"] async fn streaming() { let call_count = Arc::new(AtomicUsize::new(0)); - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let agent = client .agent(xai::GROK_3_MINI) .preamble(reasoning::TOOL_SYSTEM_PROMPT) @@ -38,7 +38,7 @@ async fn streaming() { #[ignore = "requires XAI_API_KEY - validate with grok-4-0725 once key is available"] async fn nonstreaming() { let call_count = Arc::new(AtomicUsize::new(0)); - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let agent = client .agent(xai::GROK_3_MINI) .preamble(reasoning::TOOL_SYSTEM_PROMPT) diff --git a/rig/rig-core/tests/xai/request_hook.rs b/rig/rig-core/tests/xai/request_hook.rs index c26d8b6b9..93e6a71db 100644 --- a/rig/rig-core/tests/xai/request_hook.rs +++ b/rig/rig-core/tests/xai/request_hook.rs @@ -69,6 +69,7 @@ where #[ignore = "requires XAI_API_KEY"] async fn request_hook_records_prompt_and_response() -> Result<()> { let agent = xai::Client::from_env() + .expect("client should build") .agent(xai::GROK_3_MINI) .preamble("You are a comedian here to entertain the user using humour and jokes.") .build(); diff --git a/rig/rig-core/tests/xai/streaming.rs b/rig/rig-core/tests/xai/streaming.rs index 2c26964b1..b45ca99d2 100644 --- a/rig/rig-core/tests/xai/streaming.rs +++ b/rig/rig-core/tests/xai/streaming.rs @@ -11,7 +11,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn streaming_smoke() { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let agent = client .agent(xai::completion::GROK_3_MINI) .preamble(STREAMING_PREAMBLE) diff --git a/rig/rig-core/tests/xai/streaming_tools.rs b/rig/rig-core/tests/xai/streaming_tools.rs index 31f6e0d00..d66204657 100644 --- a/rig/rig-core/tests/xai/streaming_tools.rs +++ b/rig/rig-core/tests/xai/streaming_tools.rs @@ -61,7 +61,7 @@ impl Tool for 
StatusWordTool { #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn raw_stream_emits_required_zero_arg_tool_call() { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let model = client.completion_model(xai::completion::GROK_4); let request = model .completion_request(REQUIRED_ZERO_ARG_TOOL_PROMPT) @@ -76,7 +76,7 @@ async fn raw_stream_emits_required_zero_arg_tool_call() { #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn responses_stream_preserves_tool_result_flow() { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let agent = client .agent(xai::completion::GROK_4) .preamble(XAI_STATUS_TOOL_PREAMBLE) @@ -99,7 +99,7 @@ async fn responses_stream_preserves_tool_result_flow() { #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn raw_responses_stream_preserves_tool_then_followup_text_ordering() { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let model = client.completion_model(xai::completion::GROK_4); let request = model .completion_request(XAI_STATUS_TOOL_PROMPT) diff --git a/rig/rig-core/tests/xai/tools.rs b/rig/rig-core/tests/xai/tools.rs index 4a26e4cf6..2c0e8118e 100644 --- a/rig/rig-core/tests/xai/tools.rs +++ b/rig/rig-core/tests/xai/tools.rs @@ -11,7 +11,7 @@ use crate::support::{ #[tokio::test] #[ignore = "requires XAI_API_KEY"] async fn tools_smoke() { - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let agent = client .agent(xai::completion::GROK_3_MINI) .preamble(TOOLS_PREAMBLE) diff --git a/rig/rig-core/tests/xai/typed_prompt_tools.rs b/rig/rig-core/tests/xai/typed_prompt_tools.rs index 76ee67ae1..f2f3cc8b5 100644 --- a/rig/rig-core/tests/xai/typed_prompt_tools.rs +++ b/rig/rig-core/tests/xai/typed_prompt_tools.rs @@ -75,7 +75,7 @@ impl Tool for WeatherTool { #[ignore = "requires XAI_API_KEY"] async fn prompt_typed_with_tool_call_roundtrip() -> Result<()> { let call_count = Arc::new(AtomicUsize::new(0)); - let client = xai::Client::from_env(); + let client = xai::Client::from_env().expect("client should build"); let agent = client .agent(xai::GROK_4) .preamble( diff --git a/rig/rig-core/tests/zai.rs b/rig/rig-core/tests/zai.rs index b5d4809d7..553997490 100644 --- a/rig/rig-core/tests/zai.rs +++ b/rig/rig-core/tests/zai.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Z.AI integration tests. //! //! 
Run the provider target with: diff --git a/rig/rig-derive/Cargo.toml b/rig/rig-derive/Cargo.toml index 85deeefd3..8c7c6c047 100644 --- a/rig/rig-derive/Cargo.toml +++ b/rig/rig-derive/Cargo.toml @@ -22,6 +22,7 @@ syn = { workspace = true, features = ["full"] } proc-macro = true [dev-dependencies] +anyhow = { workspace = true } rig-core = { path = "../rig-core" } serde = { workspace = true } serde_json = { workspace = true } diff --git a/rig/rig-derive/examples/rig_tool/async_tool.rs b/rig/rig-derive/examples/rig_tool/async_tool.rs index 27577eacb..e75f00710 100644 --- a/rig/rig-derive/examples/rig_tool/async_tool.rs +++ b/rig/rig-derive/examples/rig_tool/async_tool.rs @@ -25,10 +25,10 @@ async fn async_operation(input: String, delay_ms: u64) -> Result Result<(), anyhow::Error> { tracing_subscriber::fmt().pretty().init(); - let async_agent = providers::openai::Client::from_env() + let async_agent = providers::openai::Client::from_env()? .agent(providers::openai::GPT_4O) .preamble("You are an agent with tools access, always use the tools") .max_tokens(1024) @@ -38,7 +38,7 @@ async fn main() { println!("Tool definition:"); println!( "ASYNCOPERATION: {}", - serde_json::to_string_pretty(&AsyncOperation.definition(String::default()).await).unwrap() + serde_json::to_string_pretty(&AsyncOperation.definition(String::default()).await)? ); for prompt in [ @@ -49,6 +49,8 @@ async fn main() { "Process the text 'error handling' with a delay of 'not a number'", ] { println!("User: {prompt}"); - println!("Agent: {}", async_agent.prompt(prompt).await.unwrap()); + println!("Agent: {}", async_agent.prompt(prompt).await?); } + + Ok(()) } diff --git a/rig/rig-derive/examples/rig_tool/full.rs b/rig/rig-derive/examples/rig_tool/full.rs index d9c02aceb..403a1ba1a 100644 --- a/rig/rig-derive/examples/rig_tool/full.rs +++ b/rig/rig-derive/examples/rig_tool/full.rs @@ -29,10 +29,10 @@ fn string_processor(text: String, operation: String) -> Result Result<(), anyhow::Error> { tracing_subscriber::fmt().pretty().init(); - let string_agent = providers::openai::Client::from_env() + let string_agent = providers::openai::Client::from_env()? .agent(providers::openai::GPT_4O) .preamble("You are an agent with tools access, always use the tools") .max_tokens(1024) @@ -42,7 +42,7 @@ async fn main() { println!("Tool definition:"); println!( "STRINGPROCESSOR: {}", - serde_json::to_string_pretty(&StringProcessor.definition(String::default()).await).unwrap() + serde_json::to_string_pretty(&StringProcessor.definition(String::default()).await)? ); for prompt in [ @@ -54,6 +54,8 @@ async fn main() { "Perform an invalid operation on 'hello world'", ] { println!("User: {prompt}"); - println!("Agent: {}", string_agent.prompt(prompt).await.unwrap()); + println!("Agent: {}", string_agent.prompt(prompt).await?); } + + Ok(()) } diff --git a/rig/rig-derive/examples/rig_tool/simple.rs b/rig/rig-derive/examples/rig_tool/simple.rs index 0a8a6fd54..0d8a78a52 100644 --- a/rig/rig-derive/examples/rig_tool/simple.rs +++ b/rig/rig-derive/examples/rig_tool/simple.rs @@ -49,10 +49,10 @@ fn sum_numbers(numbers: Vec) -> Result { } #[tokio::main] -async fn main() { +async fn main() -> Result<(), anyhow::Error> { tracing_subscriber::fmt().pretty().init(); - let calculator_agent = providers::openai::Client::from_env() + let calculator_agent = providers::openai::Client::from_env()? 
.agent(providers::openai::GPT_4O) .preamble("You are an agent with tools access, always use the tools") .max_tokens(1024) @@ -72,6 +72,8 @@ async fn main() { "Add 100 and 200", ] { println!("User: {prompt}"); - println!("Agent: {}", calculator_agent.prompt(prompt).await.unwrap()); + println!("Agent: {}", calculator_agent.prompt(prompt).await?); } + + Ok(()) } diff --git a/rig/rig-derive/examples/rig_tool/with_description.rs b/rig/rig-derive/examples/rig_tool/with_description.rs index 43db0fab6..deeccc8d9 100644 --- a/rig/rig-derive/examples/rig_tool/with_description.rs +++ b/rig/rig-derive/examples/rig_tool/with_description.rs @@ -30,10 +30,10 @@ fn calculator(x: i32, y: i32, operation: String) -> Result Result<(), anyhow::Error> { tracing_subscriber::fmt().pretty().init(); - let calculator_agent = providers::openai::Client::from_env() + let calculator_agent = providers::openai::Client::from_env()? .agent(providers::openai::GPT_4O) .preamble("You are an agent with tools access, always use the tools") .max_tokens(1024) @@ -43,7 +43,7 @@ async fn main() { println!("Tool definition:"); println!( "CALCULATOR: {}", - serde_json::to_string_pretty(&CALCULATOR.definition(String::default()).await).unwrap() + serde_json::to_string_pretty(&CALCULATOR.definition(String::default()).await)? ); for prompt in [ @@ -55,6 +55,8 @@ async fn main() { "What is 10 / 0?", ] { println!("User: {prompt}"); - println!("Agent: {}", calculator_agent.prompt(prompt).await.unwrap()); + println!("Agent: {}", calculator_agent.prompt(prompt).await?); } + + Ok(()) } diff --git a/rig/rig-derive/src/client.rs b/rig/rig-derive/src/client.rs index 2ddccf826..db3b990fb 100644 --- a/rig/rig-derive/src/client.rs +++ b/rig/rig-derive/src/client.rs @@ -13,7 +13,10 @@ struct ClientAttr { pub fn provider_client(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); let ident = &input.ident; - let attrs = ClientAttr::parse_attributes(&input.attrs).unwrap(); + let attrs = match ClientAttr::parse_attributes(&input.attrs) { + Ok(attrs) => attrs, + Err(error) => return error.into_compile_error().into(), + }; let features: Vec = attrs.features.unwrap_or_default(); struct FeatureInfo { diff --git a/rig/rig-derive/src/custom.rs b/rig/rig-derive/src/custom.rs index e284c6910..9dbe06ef4 100644 --- a/rig/rig-derive/src/custom.rs +++ b/rig/rig-derive/src/custom.rs @@ -77,7 +77,7 @@ impl CustomAttributeParser for syn::Attribute { fn expand_tag(&self) -> syn::Result { fn function_path(meta: &ParseNestedMeta<'_>) -> syn::Result { // #[embed(embed_with = "...")] - let expr = meta.value()?.parse::().unwrap(); + let expr = meta.value()?.parse::()?; let mut value = &expr; while let syn::Expr::Group(e) = value { value = &e.expr; @@ -117,6 +117,11 @@ impl CustomAttributeParser for syn::Attribute { Err(e) => Err(e), })?; - Ok(custom_func_path.unwrap()) + custom_func_path.ok_or_else(|| { + syn::Error::new_spanned( + self, + format!("expected {EMBED_WITH} attribute: `{EMBED_WITH} = \"...\"`"), + ) + }) } } diff --git a/rig/rig-derive/src/lib.rs b/rig/rig-derive/src/lib.rs index 64545033a..4d1910daf 100644 --- a/rig/rig-derive/src/lib.rs +++ b/rig/rig-derive/src/lib.rs @@ -30,6 +30,9 @@ pub fn derive_provider_client(input: TokenStream) -> TokenStream { /// Usage can be found below: /// /// ```rust +/// use rig::Embed; +/// use rig_derive::Embed; +/// /// #[derive(Embed)] /// struct Foo { /// id: String, @@ -165,7 +168,13 @@ impl Parse for MacroArgs { .. 
}) = nv.value { - let param_name = nv.path.get_ident().unwrap().to_string(); + let Some(param_ident) = nv.path.get_ident() else { + return Err(syn::Error::new_spanned( + &nv.path, + "parameter descriptions must use identifier keys", + )); + }; + let param_name = param_ident.to_string(); param_descriptions.insert(param_name, lit_str.value()); } } @@ -210,13 +219,15 @@ impl Parse for MacroArgs { fn get_json_type(ty: &Type) -> proc_macro2::TokenStream { match ty { Type::Path(type_path) => { - let segment = &type_path.path.segments[0]; + let Some(segment) = type_path.path.segments.first() else { + return quote! { "type": "object" }; + }; let type_name = segment.ident.to_string(); // Handle Vec types if type_name == "Vec" { if let syn::PathArguments::AngleBracketed(args) = &segment.arguments - && let syn::GenericArgument::Type(inner_type) = &args.args[0] + && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first() { let inner_json_type = get_json_type(inner_type); return quote! { @@ -250,6 +261,68 @@ fn get_json_type(ty: &Type) -> proc_macro2::TokenStream { } } +fn result_type_tokens( + return_type: &ReturnType, +) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> { + let ReturnType::Type(_, ty) = return_type else { + return Err(syn::Error::new_spanned( + return_type, + "function must have a return type of Result", + )); + }; + + let Type::Path(type_path) = ty.deref() else { + return Err(syn::Error::new_spanned( + ty, + "return type must be Result", + )); + }; + + let Some(last_segment) = type_path.path.segments.last() else { + return Err(syn::Error::new_spanned( + &type_path.path, + "return type must be Result", + )); + }; + + if last_segment.ident != "Result" { + return Err(syn::Error::new_spanned( + &last_segment.ident, + "return type must be Result", + )); + } + + let PathArguments::AngleBracketed(args) = &last_segment.arguments else { + return Err(syn::Error::new_spanned( + &last_segment.arguments, + "expected angle-bracketed type parameters for Result", + )); + }; + + let mut generic_args = args.args.iter(); + let Some(output) = generic_args.next() else { + return Err(syn::Error::new_spanned( + &args.args, + "expected Result with exactly two type parameters", + )); + }; + let Some(error) = generic_args.next() else { + return Err(syn::Error::new_spanned( + &args.args, + "expected Result with exactly two type parameters", + )); + }; + + if generic_args.next().is_some() { + return Err(syn::Error::new_spanned( + &args.args, + "expected Result with exactly two type parameters", + )); + } + + Ok((quote!(#output), quote!(#error))) +} + /// A procedural macro that transforms a function into a `rig::tool::Tool` that can be used with a `rig::agent::Agent`. 
/// /// # Examples @@ -327,34 +400,9 @@ pub fn rig_tool(args: TokenStream, input: TokenStream) -> TokenStream { // Extract return type and get Output and Error types from Result let return_type = &input_fn.sig.output; - let (output_type, error_type) = match return_type { - ReturnType::Type(_, ty) => { - if let Type::Path(type_path) = ty.deref() { - if let Some(last_segment) = type_path.path.segments.last() { - if last_segment.ident == "Result" { - if let PathArguments::AngleBracketed(args) = &last_segment.arguments { - if args.args.len() == 2 { - let output = args.args.first().unwrap(); - let error = args.args.last().unwrap(); - - (quote!(#output), quote!(#error)) - } else { - panic!("Expected Result with two type parameters"); - } - } else { - panic!("Expected angle bracketed type parameters for Result"); - } - } else { - panic!("Return type must be a Result"); - } - } else { - panic!("Invalid return type"); - } - } else { - panic!("Invalid return type"); - } - } - _ => panic!("Function must have a return type"), + let (output_type, error_type) = match result_type_tokens(return_type) { + Ok(types) => types, + Err(error) => return error.into_compile_error().into(), }; // Generate PascalCase struct name from the function name diff --git a/rig/rig-derive/tests/calculator.rs b/rig/rig-derive/tests/calculator.rs index 0a4b85f12..2eddfced9 100644 --- a/rig/rig-derive/tests/calculator.rs +++ b/rig/rig-derive/tests/calculator.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + use rig::tool::Tool; use rig_derive::rig_tool; diff --git a/rig/rig-derive/tests/custom_name.rs b/rig/rig-derive/tests/custom_name.rs index 453982ab5..66098cf4c 100644 --- a/rig/rig-derive/tests/custom_name.rs +++ b/rig/rig-derive/tests/custom_name.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + use rig::tool::Tool; use rig_derive::rig_tool; diff --git a/rig/rig-derive/tests/visibility.rs b/rig/rig-derive/tests/visibility.rs index 4e9c198f2..e205a087e 100644 --- a/rig/rig-derive/tests/visibility.rs +++ b/rig/rig-derive/tests/visibility.rs @@ -1,3 +1,11 @@ +#![allow( + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic, + clippy::unwrap_used, + clippy::unreachable +)] + //! Test that `#[rig_tool]` propagates the function's visibility to the generated //! structs and static. A `pub` function should produce a `pub` tool struct that //! is accessible from outside the defining module. 
From 349bcdfb7e05c15c3a39097cfc0a3d83bb631f0e Mon Sep 17 00:00:00 2001 From: stephen Date: Thu, 23 Apr 2026 14:25:58 -0700 Subject: [PATCH 2/5] second pass --- .../rig-vectorize/src/client/filter.rs | 18 ++++++++++++------ .../openai_chat_completions_compatible.rs | 9 ++++----- rig/rig-derive/src/lib.rs | 10 +++++----- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/rig-integrations/rig-vectorize/src/client/filter.rs b/rig-integrations/rig-vectorize/src/client/filter.rs index 001768c7f..b58643e94 100644 --- a/rig-integrations/rig-vectorize/src/client/filter.rs +++ b/rig-integrations/rig-vectorize/src/client/filter.rs @@ -191,8 +191,11 @@ mod tests { let combined = filter1.and(filter2); let result = combined.into_inner(); - let Value::Object(obj) = result else { - assert!(false, "combined filter should serialize to an object"); + let Some(obj) = result.as_object() else { + assert!( + result.is_object(), + "combined filter should serialize to an object" + ); return; }; @@ -216,7 +219,7 @@ mod tests { let err = match result { Err(err) => err, Ok(()) => { - assert!(false, "OR filters should fail validation"); + assert!(result.is_err(), "OR filters should fail validation"); return; } }; @@ -225,7 +228,7 @@ mod tests { assert!(msg.contains("OR")); } other => assert!( - false, + matches!(other, VectorizeError::UnsupportedFilterOperation(_)), "expected UnsupportedFilterOperation error, got {other:?}" ), } @@ -251,8 +254,11 @@ mod tests { .and(VectorizeFilter::lt("price", json!(100))); let result = filter.into_inner(); - let Value::Object(obj) = result else { - assert!(false, "combined filter should serialize to an object"); + let Some(obj) = result.as_object() else { + assert!( + result.is_object(), + "combined filter should serialize to an object" + ); return; }; diff --git a/rig/rig-core/src/providers/internal/openai_chat_completions_compatible.rs b/rig/rig-core/src/providers/internal/openai_chat_completions_compatible.rs index 7f526683e..70e55e01c 100644 --- a/rig/rig-core/src/providers/internal/openai_chat_completions_compatible.rs +++ b/rig/rig-core/src/providers/internal/openai_chat_completions_compatible.rs @@ -250,12 +250,11 @@ where for incoming in choice.tool_calls { if let Some(existing) = tool_calls.get(&incoming.index) && profile.should_evict(existing, &incoming) + && let Some(evicted) = tool_calls.remove(&incoming.index) { - if let Some(evicted) = tool_calls.remove(&incoming.index) { - yield Ok(RawStreamingChoice::ToolCall( - finalize_completed_streaming_tool_call(evicted), - )); - } + yield Ok(RawStreamingChoice::ToolCall( + finalize_completed_streaming_tool_call(evicted), + )); } let existing_tool_call = tool_calls diff --git a/rig/rig-derive/src/lib.rs b/rig/rig-derive/src/lib.rs index 4d1910daf..2bd6bf317 100644 --- a/rig/rig-derive/src/lib.rs +++ b/rig/rig-derive/src/lib.rs @@ -29,7 +29,7 @@ pub fn derive_provider_client(input: TokenStream) -> TokenStream { /// A macro that allows you to implement the `rig::embedding::Embed` trait by deriving it. 
/// Usage can be found below: /// -/// ```rust +/// ```text /// use rig::Embed; /// use rig_derive::Embed; /// @@ -328,7 +328,7 @@ fn result_type_tokens( /// # Examples /// /// Basic usage: -/// ```rust +/// ```text /// use rig_derive::rig_tool; /// /// #[rig_tool] @@ -338,7 +338,7 @@ fn result_type_tokens( /// ``` /// /// With description: -/// ```rust +/// ```text /// use rig_derive::rig_tool; /// /// #[rig_tool(description = "Perform basic arithmetic operations")] @@ -354,7 +354,7 @@ fn result_type_tokens( /// ``` /// /// With a custom tool name: -/// ```rust +/// ```text /// use rig_derive::rig_tool; /// /// // Explicit names must be string literals that start with an ASCII letter @@ -367,7 +367,7 @@ fn result_type_tokens( /// ``` /// /// With parameter descriptions: -/// ```rust +/// ```text /// use rig_derive::rig_tool; /// /// #[rig_tool( From 93b8bffef9cb5ab9d7f9a418eba14a458aa969d2 Mon Sep 17 00:00:00 2001 From: stephen Date: Thu, 23 Apr 2026 15:50:55 -0700 Subject: [PATCH 3/5] doc updates --- rig-integrations/rig-lancedb/src/lib.rs | 9 ++++--- rig-integrations/rig-vectorize/src/lib.rs | 3 ++- rig/rig-core/src/agent/builder.rs | 4 ++-- rig/rig-core/src/agent/completion.rs | 4 ++-- rig/rig-core/src/agent/mod.rs | 6 +++-- rig/rig-core/src/client/mod.rs | 24 +++++++++++++++++++ rig/rig-core/src/client/model_listing.rs | 4 ++-- rig/rig-core/src/lib.rs | 8 ++++--- rig/rig-core/src/providers/chatgpt/mod.rs | 8 +++++-- rig/rig-core/src/providers/copilot/mod.rs | 8 +++++-- .../src/providers/openai/responses_api/mod.rs | 8 ++++++- 11 files changed, 66 insertions(+), 20 deletions(-) diff --git a/rig-integrations/rig-lancedb/src/lib.rs b/rig-integrations/rig-lancedb/src/lib.rs index 61a428a9b..b1521a03b 100644 --- a/rig-integrations/rig-lancedb/src/lib.rs +++ b/rig-integrations/rig-lancedb/src/lib.rs @@ -40,9 +40,10 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { /// # Example /// ```ignore /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; +/// use rig::client::ProviderClient; /// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel}; /// -/// let openai_client = Client::from_env(); +/// let openai_client = Client::from_env()?; /// /// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here. /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. @@ -376,9 +377,10 @@ where /// # Example /// ```ignore /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; + /// use rig::client::ProviderClient; /// use rig::providers::openai::{EmbeddingModel, Client, TEXT_EMBEDDING_ADA_002}; /// - /// let openai_client = Client::from_env(); + /// let openai_client = Client::from_env()?; /// /// let table: lancedb::Table = db.create_table("fake_definitions"); // <-- Replace with your lancedb table here. /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. @@ -438,9 +440,10 @@ where /// # Example /// ```ignore /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; + /// use rig::client::ProviderClient; /// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel}; /// - /// let openai_client = Client::from_env(); + /// let openai_client = Client::from_env()?; /// /// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here. 
/// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. diff --git a/rig-integrations/rig-vectorize/src/lib.rs b/rig-integrations/rig-vectorize/src/lib.rs index c804b9244..07b72bbd6 100644 --- a/rig-integrations/rig-vectorize/src/lib.rs +++ b/rig-integrations/rig-vectorize/src/lib.rs @@ -6,10 +6,11 @@ //! # Example //! //! ```ignore +//! use rig::client::ProviderClient; //! use rig::providers::openai; //! use rig_vectorize::VectorizeVectorStore; //! -//! let openai = openai::Client::from_env(); +//! let openai = openai::Client::from_env()?; //! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_SMALL); //! //! let vector_store = VectorizeVectorStore::new( diff --git a/rig/rig-core/src/agent/builder.rs b/rig/rig-core/src/agent/builder.rs index 697361c6b..96ec64498 100644 --- a/rig/rig-core/src/agent/builder.rs +++ b/rig/rig-core/src/agent/builder.rs @@ -56,9 +56,9 @@ pub struct WithBuilderTools { /// /// # Example /// ``` -/// use rig::{providers::openai, agent::AgentBuilder}; +/// use rig::{client::ProviderClient, providers::openai, agent::AgentBuilder}; /// -/// let openai = openai::Client::from_env(); +/// let openai = openai::Client::from_env()?; /// /// let gpt4o = openai.completion_model("gpt-4o"); /// diff --git a/rig/rig-core/src/agent/completion.rs b/rig/rig-core/src/agent/completion.rs index a6c3aeaeb..09953921c 100644 --- a/rig/rig-core/src/agent/completion.rs +++ b/rig/rig-core/src/agent/completion.rs @@ -149,9 +149,9 @@ pub(crate) async fn build_completion_request( /// /// # Example /// ``` -/// use rig::{completion::Prompt, providers::openai}; +/// use rig::{client::ProviderClient, completion::Prompt, providers::openai}; /// -/// let openai = openai::Client::from_env(); +/// let openai = openai::Client::from_env()?; /// /// let comedian_agent = openai /// .agent("gpt-4o") diff --git a/rig/rig-core/src/agent/mod.rs b/rig/rig-core/src/agent/mod.rs index 7c3461e66..122fde8f8 100644 --- a/rig/rig-core/src/agent/mod.rs +++ b/rig/rig-core/src/agent/mod.rs @@ -19,11 +19,12 @@ //! # Example //! ```rust //! use rig::{ +//! client::ProviderClient, //! completion::{Chat, Completion, Prompt}, //! providers::openai, //! }; //! -//! let openai = openai::Client::from_env(); +//! let openai = openai::Client::from_env()?; //! //! // Configure the agent //! let agent = openai.agent("gpt-4o") @@ -64,6 +65,7 @@ //! RAG Agent example //! ```rust //! use rig::{ +//! client::ProviderClient, //! completion::Prompt, //! embeddings::EmbeddingsBuilder, //! providers::openai, @@ -71,7 +73,7 @@ //! }; //! //! // Initialize OpenAI client -//! let openai = openai::Client::from_env(); +//! let openai = openai::Client::from_env()?; //! //! // Initialize OpenAI embedding model //! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002); diff --git a/rig/rig-core/src/client/mod.rs b/rig/rig-core/src/client/mod.rs index 1baf8b757..d3cca0a65 100644 --- a/rig/rig-core/src/client/mod.rs +++ b/rig/rig-core/src/client/mod.rs @@ -53,27 +53,48 @@ pub enum ClientBuilderError { InvalidProperty(&'static str), } +/// Errors returned while constructing provider clients from environment variables or explicit input. +/// +/// Provider-specific client constructors use this error for configuration problems that can be +/// detected before any model request is sent, such as missing API keys, invalid environment +/// values, or invalid builder configuration. 
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ProviderClientError {
+    /// A required or optional environment variable could not be read as valid Unicode.
+    ///
+    /// For required variables, this variant is also returned when the variable is not present.
     #[error("environment variable `{name}` is not set or is invalid")]
     EnvironmentVariable {
+        /// The environment variable name.
         name: &'static str,
+        /// The underlying environment lookup error.
         #[source]
         source: VarError,
     },
+    /// The underlying provider client builder failed while constructing HTTP configuration.
     #[error(transparent)]
     Http(#[from] http_client::Error),
+    /// The provider received an unsupported or incomplete configuration.
     #[error("{0}")]
     InvalidConfiguration(&'static str),
 }
+/// Result type returned by provider client construction helpers.
 pub type ProviderClientResult<T> = std::result::Result<T, ProviderClientError>;
+/// Read a required environment variable for provider client construction.
+///
+/// Returns [`ProviderClientError::EnvironmentVariable`] when the variable is missing or contains
+/// invalid Unicode.
 pub fn required_env_var(name: &'static str) -> ProviderClientResult<String> {
     std::env::var(name).map_err(|source| ProviderClientError::EnvironmentVariable { name, source })
 }
+/// Read an optional environment variable for provider client construction.
+///
+/// Missing variables return `Ok(None)`. Variables containing invalid Unicode return
+/// [`ProviderClientError::EnvironmentVariable`].
 pub fn optional_env_var(name: &'static str) -> ProviderClientResult<Option<String>> {
     match std::env::var(name) {
         Ok(value) => Ok(Some(value)),
@@ -85,7 +106,9 @@ pub fn optional_env_var(name: &'static str) -> ProviderClientResult Result where Self: Sized;
diff --git a/rig/rig-core/src/client/model_listing.rs b/rig/rig-core/src/client/model_listing.rs
index 3e7a244c7..5ce907b8e 100644
--- a/rig/rig-core/src/client/model_listing.rs
+++ b/rig/rig-core/src/client/model_listing.rs
@@ -55,10 +55,10 @@ pub trait ModelListingClient {
     /// # Example
     ///
     /// ```rust,ignore
-    /// use rig::client::ModelListingClient;
+    /// use rig::client::{ModelListingClient, ProviderClient};
     /// use rig::providers::openai::Client;
     ///
-    /// let openai = Client::from_env();
+    /// let openai = Client::from_env()?;
     /// let models = openai.list_models().await?;
     ///
     /// println!("Found {} models", models.len());
diff --git a/rig/rig-core/src/lib.rs b/rig/rig-core/src/lib.rs
index a38d9aad2..36f9be53d 100644
--- a/rig/rig-core/src/lib.rs
+++ b/rig/rig-core/src/lib.rs
@@ -24,13 +24,13 @@
 //!
 //! # Simple example:
 //! ```
-//! use rig::{client::CompletionClient, completion::Prompt, providers::openai};
+//! use rig::{client::{CompletionClient, ProviderClient}, completion::Prompt, providers::openai};
 //!
 //! #[tokio::main]
-//! async fn main() {
+//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
 //!     // Create OpenAI client and agent.
 //!     // This requires the `OPENAI_API_KEY` environment variable to be set.
-//!     let openai_client = openai::Client::from_env();
+//!     let openai_client = openai::Client::from_env()?;
 //!
 //!     let gpt4 = openai_client.agent("gpt-4").build();
 //!
 //!
@@ -41,6 +41,8 @@
 //!         .expect("Failed to prompt GPT-4");
 //!
 //!     println!("GPT-4: {response}");
 //!
+//!
+//!     Ok(())
 //! }
 //! ```
 //!
Note: using `#[tokio::main]` requires you enable tokio's `macros` and `rt-multi-thread` features diff --git a/rig/rig-core/src/providers/chatgpt/mod.rs b/rig/rig-core/src/providers/chatgpt/mod.rs index ef1bc105c..9e248af59 100644 --- a/rig/rig-core/src/providers/chatgpt/mod.rs +++ b/rig/rig-core/src/providers/chatgpt/mod.rs @@ -5,11 +5,15 @@ //! //! # Example //! ```no_run -//! use rig::client::CompletionClient; +//! use rig::client::{CompletionClient, ProviderClient}; //! use rig::providers::chatgpt; //! -//! let client = chatgpt::Client::from_env(); +//! # fn example() -> Result<(), Box> { +//! let client = chatgpt::Client::from_env()?; //! let model = client.completion_model(chatgpt::GPT_5_3_CODEX); +//! # let _ = model; +//! # Ok(()) +//! # } //! ``` mod auth; diff --git a/rig/rig-core/src/providers/copilot/mod.rs b/rig/rig-core/src/providers/copilot/mod.rs index 98f3cf4dc..0bde2bd84 100644 --- a/rig/rig-core/src/providers/copilot/mod.rs +++ b/rig/rig-core/src/providers/copilot/mod.rs @@ -9,11 +9,15 @@ //! //! # Example //! ```no_run -//! use rig::client::CompletionClient; +//! use rig::client::{CompletionClient, ProviderClient}; //! use rig::providers::copilot; //! -//! let client = copilot::Client::from_env(); +//! # fn example() -> Result<(), Box> { +//! let client = copilot::Client::from_env()?; //! let model = client.completion_model(copilot::GPT_4O); +//! # let _ = model; +//! # Ok(()) +//! # } //! ``` mod auth; diff --git a/rig/rig-core/src/providers/openai/responses_api/mod.rs b/rig/rig-core/src/providers/openai/responses_api/mod.rs index c3388ef8b..1b2aded6b 100644 --- a/rig/rig-core/src/providers/openai/responses_api/mod.rs +++ b/rig/rig-core/src/providers/openai/responses_api/mod.rs @@ -4,8 +4,14 @@ //! //! If you'd like to switch back to the regular Completions API, you can do so by using the `.completions_api()` function - see below for an example: //! ```rust -//! let openai_client = rig::providers::openai::Client::from_env(); +//! use rig::client::{CompletionClient, ProviderClient}; +//! +//! # fn example() -> Result<(), Box> { +//! let openai_client = rig::providers::openai::Client::from_env()?; //! let model = openai_client.completion_model("gpt-4o").completions_api(); +//! # let _ = model; +//! # Ok(()) +//! # } //! 
``` use super::InputAudio; use super::completion::ToolChoice; From abf694375b7c3c5778e5c83588dba64d49d35262 Mon Sep 17 00:00:00 2001 From: stephen Date: Thu, 23 Apr 2026 18:37:10 -0700 Subject: [PATCH 4/5] more lints --- Cargo.toml | 4 + .../examples/vector_search_surreal.rs | 6 +- .../rig-surrealdb/examples/vector_store.rs | 24 +++- .../src/agent/prompt_request/streaming.rs | 8 +- rig/rig-core/src/providers/azure.rs | 14 ++- .../src/providers/gemini/completion.rs | 2 +- .../anthropic/think_tool_with_other_tools.rs | 15 ++- rig/rig-core/tests/chatgpt/extractor_usage.rs | 57 ++++----- rig/rig-core/tests/chatgpt/multi_extract.rs | 6 +- .../tests/chatgpt/permission_control.rs | 10 +- rig/rig-core/tests/chatgpt/request_hook.rs | 8 +- rig/rig-core/tests/copilot/extractor_usage.rs | 55 ++++----- rig/rig-core/tests/copilot/multi_extract.rs | 2 +- .../tests/copilot/permission_control.rs | 10 +- rig/rig-core/tests/copilot/request_hook.rs | 8 +- .../tests/copilot/typed_prompt_tools.rs | 2 +- .../tests/deepseek/extractor_usage.rs | 111 +++++++++++++----- rig/rig-core/tests/deepseek/multi_extract.rs | 6 +- .../tests/deepseek/permission_control.rs | 22 +++- rig/rig-core/tests/deepseek/request_hook.rs | 20 +++- rig/rig-core/tests/groq/extractor_usage.rs | 57 ++++----- rig/rig-core/tests/groq/multi_extract.rs | 6 +- rig/rig-core/tests/groq/permission_control.rs | 10 +- rig/rig-core/tests/groq/request_hook.rs | 8 +- rig/rig-core/tests/groq/typed_prompt_tools.rs | 2 +- .../tests/llamacpp/extractor_usage.rs | 57 ++++----- rig/rig-core/tests/llamacpp/multi_extract.rs | 6 +- .../tests/llamacpp/permission_control.rs | 10 +- rig/rig-core/tests/llamacpp/request_hook.rs | 8 +- .../tests/llamacpp/typed_prompt_tools.rs | 2 +- .../tests/llamafile/extractor_usage.rs | 55 ++++----- rig/rig-core/tests/llamafile/multi_extract.rs | 2 +- .../tests/llamafile/permission_control.rs | 10 +- rig/rig-core/tests/llamafile/request_hook.rs | 8 +- .../tests/llamafile/typed_prompt_tools.rs | 2 +- rig/rig-core/tests/mistral/extractor_usage.rs | 55 ++++----- rig/rig-core/tests/mistral/multi_extract.rs | 2 +- .../tests/mistral/permission_control.rs | 10 +- rig/rig-core/tests/mistral/request_hook.rs | 8 +- .../tests/mistral/typed_prompt_tools.rs | 2 +- rig/rig-core/tests/openai/extractor_usage.rs | 55 ++++----- rig/rig-core/tests/openai/multi_extract.rs | 2 +- .../tests/openai/permission_control.rs | 10 +- rig/rig-core/tests/openai/request_hook.rs | 8 +- .../tests/openai/typed_prompt_tools.rs | 2 +- rig/rig-core/tests/openai/websocket.rs | 2 +- .../tests/openrouter/extractor_usage.rs | 55 ++++----- .../tests/openrouter/multi_extract.rs | 2 +- .../tests/openrouter/permission_control.rs | 10 +- rig/rig-core/tests/openrouter/request_hook.rs | 8 +- .../tests/openrouter/typed_prompt_tools.rs | 2 +- rig/rig-core/tests/xai/extractor_usage.rs | 55 ++++----- rig/rig-core/tests/xai/multi_extract.rs | 2 +- rig/rig-core/tests/xai/permission_control.rs | 10 +- rig/rig-core/tests/xai/request_hook.rs | 8 +- rig/rig-core/tests/xai/typed_prompt_tools.rs | 2 +- 56 files changed, 537 insertions(+), 406 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e91293fb7..c8bcbf5de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,13 @@ exclude = [ [workspace.lints.clippy] dbg_macro = "forbid" +await_holding_lock = "deny" +await_holding_refcell_ref = "deny" expect_used = "deny" +expect_fun_call = "deny" indexing_slicing = "deny" panic = "deny" +panic_in_result_fn = "deny" todo = "forbid" unimplemented = "forbid" unreachable = "deny" diff --git 
a/rig-integrations/rig-surrealdb/examples/vector_search_surreal.rs b/rig-integrations/rig-surrealdb/examples/vector_search_surreal.rs index 544381cf4..c40165743 100644 --- a/rig-integrations/rig-surrealdb/examples/vector_search_surreal.rs +++ b/rig-integrations/rig-surrealdb/examples/vector_search_surreal.rs @@ -99,7 +99,11 @@ async fn main() -> Result<(), anyhow::Error> { let results = vector_store.top_n::(req).await?; println!("{} results for query: {}", results.len(), query); - assert_eq!(results.len(), 1); + anyhow::ensure!( + results.len() == 1, + "expected one result after threshold filtering, got {}", + results.len() + ); for (distance, _id, doc) in results.iter() { println!("Result distance {distance} for word: {doc}"); diff --git a/rig-integrations/rig-surrealdb/examples/vector_store.rs b/rig-integrations/rig-surrealdb/examples/vector_store.rs index 0b4eb1cc6..af9264578 100644 --- a/rig-integrations/rig-surrealdb/examples/vector_store.rs +++ b/rig-integrations/rig-surrealdb/examples/vector_store.rs @@ -70,11 +70,19 @@ async fn main() -> Result<(), anyhow::Error> { let results = vector_store.top_n::(req).await?; - assert_eq!(results.len(), 3); + anyhow::ensure!( + results.len() == 3, + "expected three unfiltered results, got {}", + results.len() + ); let Some(first_result) = results.first() else { return Err(anyhow::anyhow!("expected at least one result")); }; - assert_eq!(first_result.2.topic, "pasta carbonara"); + anyhow::ensure!( + first_result.2.topic == "pasta carbonara", + "expected first result to be pasta carbonara, got {}", + first_result.2.topic + ); println!("{} results for query: {}", results.len(), query); for (distance, _id, doc) in results.iter() { @@ -98,11 +106,19 @@ async fn main() -> Result<(), anyhow::Error> { let results = vector_store.top_n::(req).await?; println!("{} results for query: {}", results.len(), query); - assert_eq!(results.len(), 1); + anyhow::ensure!( + results.len() == 1, + "expected one filtered result, got {}", + results.len() + ); let Some(filtered_result) = results.first() else { return Err(anyhow::anyhow!("expected one filtered result")); }; - assert_eq!(filtered_result.2.topic, "pasta carbonara"); + anyhow::ensure!( + filtered_result.2.topic == "pasta carbonara", + "expected filtered result to be pasta carbonara, got {}", + filtered_result.2.topic + ); for (distance, _id, doc) in results.iter() { println!("Result distance {distance} for topic: {doc}"); diff --git a/rig/rig-core/src/agent/prompt_request/streaming.rs b/rig/rig-core/src/agent/prompt_request/streaming.rs index df07702b4..b69a38826 100644 --- a/rig/rig-core/src/agent/prompt_request/streaming.rs +++ b/rig/rig-core/src/agent/prompt_request/streaming.rs @@ -1374,8 +1374,8 @@ mod tests { bg_handle.await?; let leaks = leak_count.load(Ordering::Relaxed); - assert_eq!( - leaks, 0, + anyhow::ensure!( + leaks == 0, "SPAN LEAK DETECTED: Background logger was inside unexpected spans {leaks} times. \ This indicates that span.enter() is being used inside async_stream instead of .instrument()" ); @@ -1433,13 +1433,13 @@ mod tests { .ok_or_else(|| anyhow::anyhow!("final response should include history"))?; // Should contain at least the user message - assert!( + anyhow::ensure!( history.iter().any(|m| matches!(m, Message::User { .. })), "History should contain the user message" ); // Should contain the assistant response - assert!( + anyhow::ensure!( history .iter() .any(|m| matches!(m, Message::Assistant { .. 
})), diff --git a/rig/rig-core/src/providers/azure.rs b/rig/rig-core/src/providers/azure.rs index 09dd0b93e..776ecdb58 100644 --- a/rig/rig-core/src/providers/azure.rs +++ b/rig/rig-core/src/providers/azure.rs @@ -1107,7 +1107,11 @@ mod azure_tests { let model = client.embedding_model_with_ndims(TEXT_EMBEDDING_3_SMALL, ndims); let embedding = model.embed_text("Hello, world!").await?; - assert!(embedding.vec.len() == ndims); + anyhow::ensure!( + embedding.vec.len() == ndims, + "expected embedding dimensions {ndims}, got {}", + embedding.vec.len() + ); tracing::info!("Azure dimensions embedding: {:?}", embedding); Ok(()) @@ -1162,8 +1166,12 @@ mod azure_tests { .prompt_typed("Hello! My name is John Doe and I'm 54 years old.") .await?; - assert!(result.name == "John Doe"); - assert!(result.age == 54); + anyhow::ensure!( + result.name == "John Doe", + "expected name John Doe, got {}", + result.name + ); + anyhow::ensure!(result.age == 54, "expected age 54, got {}", result.age); tracing::info!("Extracted person: {:?}", result); Ok(()) diff --git a/rig/rig-core/src/providers/gemini/completion.rs b/rig/rig-core/src/providers/gemini/completion.rs index d566aa082..7b97fb86e 100644 --- a/rig/rig-core/src/providers/gemini/completion.rs +++ b/rig/rig-core/src/providers/gemini/completion.rs @@ -2951,7 +2951,7 @@ mod tests { .await?; println!("Response: {response_text}"); // Gemini should have been able to see the image and potentially describe its color - assert!(!response_text.is_empty(), "Response should not be empty"); + anyhow::ensure!(!response_text.is_empty(), "Response should not be empty"); Ok(()) } diff --git a/rig/rig-core/tests/anthropic/think_tool_with_other_tools.rs b/rig/rig-core/tests/anthropic/think_tool_with_other_tools.rs index e2300510a..96e52a143 100644 --- a/rig/rig-core/tests/anthropic/think_tool_with_other_tools.rs +++ b/rig/rig-core/tests/anthropic/think_tool_with_other_tools.rs @@ -311,8 +311,7 @@ async fn think_tool_with_other_tools() -> Result<()> { ) .max_turns(10) .extended_details() - .await - .expect("prompt should succeed"); + .await?; assert_mentions_expected_number(&response.output, 25); assert_contains_any_case_insensitive( @@ -320,22 +319,22 @@ async fn think_tool_with_other_tools() -> Result<()> { &["out of stock", "express shipping", "110.99", "$110.99"], ); - assert!( + anyhow::ensure!( calculator_calls.load(Ordering::SeqCst) >= 1, "calculator should be invoked at least once" ); - assert!( + anyhow::ensure!( database_lookup_calls.load(Ordering::SeqCst) >= 2, "database lookup should be invoked for both shipping and inventory" ); let messages = response .messages - .expect("extended details should include messages"); + .ok_or_else(|| anyhow::anyhow!("extended details should include messages"))?; let tool_calls = collect_assistant_tool_calls(&messages); for tool_name in ["think", "calculator", "database_lookup"] { - assert!( + anyhow::ensure!( tool_calls.iter().any(|(name, _)| name == tool_name), "expected a {tool_name} tool call, saw {:?}", tool_calls @@ -350,11 +349,11 @@ async fn think_tool_with_other_tools() -> Result<()> { .filter(|(name, _)| name == "database_lookup") .filter_map(|(_, args)| args.get("query").and_then(|value| value.as_str())) .collect::>(); - assert!( + anyhow::ensure!( queries.contains(&"product_inventory"), "expected product_inventory lookup, saw {queries:?}" ); - assert!( + anyhow::ensure!( queries.contains(&"shipping_rates"), "expected shipping_rates lookup, saw {queries:?}" ); diff --git a/rig/rig-core/tests/chatgpt/extractor_usage.rs 
b/rig/rig-core/tests/chatgpt/extractor_usage.rs index 7084b8828..30900ae92 100644 --- a/rig/rig-core/tests/chatgpt/extractor_usage.rs +++ b/rig/rig-core/tests/chatgpt/extractor_usage.rs @@ -1,6 +1,6 @@ //! Integration tests for ChatGPT extractor usage tracking. -use anyhow::Result; +use anyhow::{Result, anyhow}; use rig::client::CompletionClient; use rig::extractor::ExtractionResponse; use schemars::JsonSchema; @@ -23,20 +23,21 @@ struct Address { zip_code: Option, } -fn assert_compatible_professions(left: Option<&str>, right: Option<&str>) { +fn assert_compatible_professions(left: Option<&str>, right: Option<&str>) -> Result<()> { let left = left - .expect("profession should be present") + .ok_or_else(|| anyhow!("profession should be present"))? .trim() .to_ascii_lowercase(); let right = right - .expect("profession should be present") + .ok_or_else(|| anyhow!("profession should be present"))? .trim() .to_ascii_lowercase(); - assert!( + anyhow::ensure!( left == right || left.contains(&right) || right.contains(&left), "expected compatible professions, got {left:?} and {right:?}" ); + Ok(()) } #[tokio::test] @@ -48,9 +49,9 @@ async fn extract_backward_compatibility() -> Result<()> { .extract("John Doe is a 30 year old software engineer.") .await?; - assert_eq!(person.name, Some("John Doe".to_string())); - assert_eq!(person.age, Some(30)); - assert_eq!(person.profession, Some("software engineer".to_string())); + anyhow::ensure!(person.name.as_deref() == Some("John Doe")); + anyhow::ensure!(person.age == Some(30)); + anyhow::ensure!(person.profession.as_deref() == Some("software engineer")); Ok(()) } @@ -64,12 +65,12 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { .extract_with_usage("Jane Smith is a 45 year old data scientist.") .await?; - assert_eq!(response.data.name, Some("Jane Smith".to_string())); - assert_eq!(response.data.age, Some(45)); - assert_eq!(response.data.profession, Some("data scientist".to_string())); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.output_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.name.as_deref() == Some("Jane Smith")); + anyhow::ensure!(response.data.age == Some(45)); + anyhow::ensure!(response.data.profession.as_deref() == Some("data scientist")); + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.output_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -92,12 +93,12 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { ) .await?; - assert_eq!(response.data.street, Some("123 Main St".to_string())); - assert_eq!(response.data.city, Some("Springfield".to_string())); - assert_eq!(response.data.state, Some("IL".to_string())); - assert_eq!(response.data.zip_code, Some("62701".to_string())); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.street.as_deref() == Some("123 Main St")); + anyhow::ensure!(response.data.city.as_deref() == Some("Springfield")); + anyhow::ensure!(response.data.state.as_deref() == Some("IL")); + anyhow::ensure!(response.data.zip_code.as_deref() == Some("62701")); + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -112,15 +113,15 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { let person = extractor.extract(text).await?; let response = extractor.extract_with_usage(text).await?; - assert_eq!(person.name, 
Some("Bob Johnson".to_string())); - assert_eq!(response.data.name, Some("Bob Johnson".to_string())); - assert_eq!(person.age, Some(55)); - assert_eq!(response.data.age, Some(55)); + anyhow::ensure!(person.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(response.data.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(person.age == Some(55)); + anyhow::ensure!(response.data.age == Some(55)); assert_compatible_professions( person.profession.as_deref(), response.data.profession.as_deref(), - ); - assert!(response.usage.total_tokens > 0, "usage should be populated"); + )?; + anyhow::ensure!(response.usage.total_tokens > 0, "usage should be populated"); Ok(()) } @@ -134,13 +135,13 @@ async fn usage_tracking_works_for_different_schemas() -> Result<()> { let person_response = person_extractor .extract_with_usage("Alice is a 25 year old developer.") .await?; - assert!(person_response.usage.total_tokens > 0); + anyhow::ensure!(person_response.usage.total_tokens > 0); let address_extractor = client.extractor::
(LIVE_MODEL).build(); let address_response = address_extractor .extract_with_usage("456 Oak Avenue, Cambridge, MA 02139") .await?; - assert!(address_response.usage.total_tokens > 0); + anyhow::ensure!(address_response.usage.total_tokens > 0); Ok(()) } diff --git a/rig/rig-core/tests/chatgpt/multi_extract.rs b/rig/rig-core/tests/chatgpt/multi_extract.rs index 05e54f094..6dc653248 100644 --- a/rig/rig-core/tests/chatgpt/multi_extract.rs +++ b/rig/rig-core/tests/chatgpt/multi_extract.rs @@ -73,7 +73,11 @@ async fn batch_multi_extract_chain() -> Result<()> { ) .await?; - assert_eq!(responses.len(), 3); + anyhow::ensure!( + responses.len() == 3, + "expected three responses, got {}", + responses.len() + ); for response in responses { assert_nonempty_response(&response); } diff --git a/rig/rig-core/tests/chatgpt/permission_control.rs b/rig/rig-core/tests/chatgpt/permission_control.rs index c4c62bd50..65c83afae 100644 --- a/rig/rig-core/tests/chatgpt/permission_control.rs +++ b/rig/rig-core/tests/chatgpt/permission_control.rs @@ -177,8 +177,8 @@ async fn permission_control_prompt_example() -> Result<()> { .await?; let last = last_result.lock().expect("lock last_result").clone(); - assert_eq!(last.as_deref(), Some("hello world")); - assert_eq!(call_count.load(Ordering::SeqCst), 2); + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!(call_count.load(Ordering::SeqCst) == 2); Ok(()) } @@ -213,7 +213,7 @@ async fn permission_control_streaming_example() -> Result<()> { let final_response = stream_to_stdout(&mut stream).await?; let last = last_result.lock().expect("lock last_result").clone(); assert_nonempty_response(final_response.response()); - assert!( + anyhow::ensure!( final_response .response() .to_ascii_lowercase() @@ -221,8 +221,8 @@ async fn permission_control_streaming_example() -> Result<()> { "expected the streamed final response to mention the file content, got {:?}", final_response.response() ); - assert_eq!(last.as_deref(), Some("hello world")); - assert_eq!(call_count.load(Ordering::SeqCst), 2); + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!(call_count.load(Ordering::SeqCst) == 2); Ok(()) } diff --git a/rig/rig-core/tests/chatgpt/request_hook.rs b/rig/rig-core/tests/chatgpt/request_hook.rs index fc7bcdec6..df9fe1e22 100644 --- a/rig/rig-core/tests/chatgpt/request_hook.rs +++ b/rig/rig-core/tests/chatgpt/request_hook.rs @@ -87,8 +87,8 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .await?; assert_nonempty_response(&response); - assert_eq!(hook.prompt_calls.load(Ordering::SeqCst), 1); - assert_eq!(hook.response_calls.load(Ordering::SeqCst), 1); + anyhow::ensure!(hook.prompt_calls.load(Ordering::SeqCst) == 1); + anyhow::ensure!(hook.response_calls.load(Ordering::SeqCst) == 1); let seen_prompt = hook .seen_prompt @@ -101,12 +101,12 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .map_err(|_| anyhow!("response hook state unavailable"))? 
.clone(); - assert!( + anyhow::ensure!( seen_prompt .as_deref() .is_some_and(|prompt| prompt.contains("Entertain me!")) ); - assert!( + anyhow::ensure!( seen_response .as_deref() .is_some_and(|captured| !captured.is_empty()) diff --git a/rig/rig-core/tests/copilot/extractor_usage.rs b/rig/rig-core/tests/copilot/extractor_usage.rs index 8c511f3ea..48d6de645 100644 --- a/rig/rig-core/tests/copilot/extractor_usage.rs +++ b/rig/rig-core/tests/copilot/extractor_usage.rs @@ -23,20 +23,21 @@ struct Address { zip_code: Option, } -fn assert_compatible_professions(left: Option<&str>, right: Option<&str>) { +fn assert_compatible_professions(left: Option<&str>, right: Option<&str>) -> Result<()> { let left = left - .expect("profession should be present") + .ok_or_else(|| anyhow::anyhow!("profession should be present"))? .trim() .to_ascii_lowercase(); let right = right - .expect("profession should be present") + .ok_or_else(|| anyhow::anyhow!("profession should be present"))? .trim() .to_ascii_lowercase(); - assert!( + anyhow::ensure!( left == right || left.contains(&right) || right.contains(&left), "expected compatible professions, got {left:?} and {right:?}" ); + Ok(()) } #[tokio::test] @@ -48,9 +49,9 @@ async fn extract_backward_compatibility() -> Result<()> { .extract("John Doe is a 30 year old software engineer.") .await?; - assert_eq!(person.name, Some("John Doe".to_string())); - assert_eq!(person.age, Some(30)); - assert_eq!(person.profession, Some("software engineer".to_string())); + anyhow::ensure!(person.name.as_deref() == Some("John Doe")); + anyhow::ensure!(person.age == Some(30)); + anyhow::ensure!(person.profession.as_deref() == Some("software engineer")); Ok(()) } @@ -64,12 +65,12 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { .extract_with_usage("Jane Smith is a 45 year old data scientist.") .await?; - assert_eq!(response.data.name, Some("Jane Smith".to_string())); - assert_eq!(response.data.age, Some(45)); - assert_eq!(response.data.profession, Some("data scientist".to_string())); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.output_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.name.as_deref() == Some("Jane Smith")); + anyhow::ensure!(response.data.age == Some(45)); + anyhow::ensure!(response.data.profession.as_deref() == Some("data scientist")); + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.output_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -92,12 +93,12 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { ) .await?; - assert_eq!(response.data.street, Some("123 Main St".to_string())); - assert_eq!(response.data.city, Some("Springfield".to_string())); - assert_eq!(response.data.state, Some("IL".to_string())); - assert_eq!(response.data.zip_code, Some("62701".to_string())); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.street.as_deref() == Some("123 Main St")); + anyhow::ensure!(response.data.city.as_deref() == Some("Springfield")); + anyhow::ensure!(response.data.state.as_deref() == Some("IL")); + anyhow::ensure!(response.data.zip_code.as_deref() == Some("62701")); + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -112,15 +113,15 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { let person = extractor.extract(text).await?; let response = 
extractor.extract_with_usage(text).await?; - assert_eq!(person.name, Some("Bob Johnson".to_string())); - assert_eq!(response.data.name, Some("Bob Johnson".to_string())); - assert_eq!(person.age, Some(55)); - assert_eq!(response.data.age, Some(55)); + anyhow::ensure!(person.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(response.data.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(person.age == Some(55)); + anyhow::ensure!(response.data.age == Some(55)); assert_compatible_professions( person.profession.as_deref(), response.data.profession.as_deref(), - ); - assert!(response.usage.total_tokens > 0, "usage should be populated"); + )?; + anyhow::ensure!(response.usage.total_tokens > 0, "usage should be populated"); Ok(()) } @@ -132,13 +133,13 @@ async fn usage_tracking_works_for_different_schemas() -> Result<()> { let person_response = person_extractor .extract_with_usage("Alice is a 25 year old developer.") .await?; - assert!(person_response.usage.total_tokens > 0); + anyhow::ensure!(person_response.usage.total_tokens > 0); let address_extractor = live_client().extractor::
(LIVE_LIGHT_MODEL).build(); let address_response = address_extractor .extract_with_usage("456 Oak Avenue, Cambridge, MA 02139") .await?; - assert!(address_response.usage.total_tokens > 0); + anyhow::ensure!(address_response.usage.total_tokens > 0); Ok(()) } diff --git a/rig/rig-core/tests/copilot/multi_extract.rs b/rig/rig-core/tests/copilot/multi_extract.rs index 27887403c..73a785ace 100644 --- a/rig/rig-core/tests/copilot/multi_extract.rs +++ b/rig/rig-core/tests/copilot/multi_extract.rs @@ -73,7 +73,7 @@ async fn batch_multi_extract_chain() -> Result<()> { ) .await?; - assert_eq!(responses.len(), 3); + anyhow::ensure!(responses.len() == 3); for response in responses { assert_nonempty_response(&response); } diff --git a/rig/rig-core/tests/copilot/permission_control.rs b/rig/rig-core/tests/copilot/permission_control.rs index 08e1acd7b..adf53d62b 100644 --- a/rig/rig-core/tests/copilot/permission_control.rs +++ b/rig/rig-core/tests/copilot/permission_control.rs @@ -179,8 +179,8 @@ async fn permission_control_prompt_example() -> Result<()> { .await?; let last = last_result.lock().expect("lock last_result").clone(); - assert_eq!(last.as_deref(), Some("hello world")); - assert!( + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!( call_count.load(Ordering::SeqCst) >= 2, "expected at least one skipped call followed by an allowed call" ); @@ -218,7 +218,7 @@ async fn permission_control_streaming_example() -> Result<()> { let final_response = stream_to_stdout(&mut stream).await?; let last = last_result.lock().expect("lock last_result").clone(); assert_nonempty_response(final_response.response()); - assert!( + anyhow::ensure!( final_response .response() .to_ascii_lowercase() @@ -226,8 +226,8 @@ async fn permission_control_streaming_example() -> Result<()> { "expected the streamed final response to mention the file content, got {:?}", final_response.response() ); - assert_eq!(last.as_deref(), Some("hello world")); - assert!( + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!( call_count.load(Ordering::SeqCst) >= 2, "expected at least one skipped call followed by an allowed call" ); diff --git a/rig/rig-core/tests/copilot/request_hook.rs b/rig/rig-core/tests/copilot/request_hook.rs index b1455cc3d..374cc68cd 100644 --- a/rig/rig-core/tests/copilot/request_hook.rs +++ b/rig/rig-core/tests/copilot/request_hook.rs @@ -87,8 +87,8 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .await?; assert_nonempty_response(&response); - assert_eq!(hook.prompt_calls.load(Ordering::SeqCst), 1); - assert_eq!(hook.response_calls.load(Ordering::SeqCst), 1); + anyhow::ensure!(hook.prompt_calls.load(Ordering::SeqCst) == 1); + anyhow::ensure!(hook.response_calls.load(Ordering::SeqCst) == 1); let seen_prompt = hook .seen_prompt @@ -101,12 +101,12 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .map_err(|_| anyhow!("response hook state unavailable"))? 
.clone(); - assert!( + anyhow::ensure!( seen_prompt .as_deref() .is_some_and(|prompt| prompt.contains("Entertain me!")) ); - assert!( + anyhow::ensure!( seen_response .as_deref() .is_some_and(|captured| !captured.is_empty()) diff --git a/rig/rig-core/tests/copilot/typed_prompt_tools.rs b/rig/rig-core/tests/copilot/typed_prompt_tools.rs index f7a75bb2a..e8cfa5066 100644 --- a/rig/rig-core/tests/copilot/typed_prompt_tools.rs +++ b/rig/rig-core/tests/copilot/typed_prompt_tools.rs @@ -83,7 +83,7 @@ async fn prompt_typed_with_tool_call_roundtrip() -> Result<()> { .prompt_typed("Hello, what's the weather in London?") .await?; - assert!( + anyhow::ensure!( call_count.load(Ordering::SeqCst) >= 1, "expected the weather tool to be executed at least once" ); diff --git a/rig/rig-core/tests/deepseek/extractor_usage.rs b/rig/rig-core/tests/deepseek/extractor_usage.rs index e799c527e..b2a0c4fd8 100644 --- a/rig/rig-core/tests/deepseek/extractor_usage.rs +++ b/rig/rig-core/tests/deepseek/extractor_usage.rs @@ -1,6 +1,6 @@ //! Integration tests for DeepSeek extractor usage tracking. -use anyhow::Result; +use anyhow::{Result, anyhow}; use rig::client::{CompletionClient, ProviderClient}; use rig::extractor::ExtractionResponse; use rig::message::Message; @@ -23,17 +23,18 @@ struct Address { zip_code: Option, } -fn assert_compatible_professions(left: Option<&str>, right: &str) { +fn assert_compatible_professions(left: Option<&str>, right: &str) -> Result<()> { let left = left - .expect("profession should be present") + .ok_or_else(|| anyhow!("profession should be present"))? .trim() .to_ascii_lowercase(); let right = right.trim().to_ascii_lowercase(); - assert!( + anyhow::ensure!( left == right || left.contains(&right) || right.contains(&left), "expected compatible professions, got {left:?} and {right:?}" ); + Ok(()) } #[tokio::test] @@ -46,9 +47,17 @@ async fn extract_backward_compatibility() -> Result<()> { .extract("John Doe is a 30 year old software engineer.") .await?; - assert_eq!(person.name, Some("John Doe".to_string())); - assert_eq!(person.age, Some(30)); - assert_compatible_professions(person.profession.as_deref(), "software engineer"); + anyhow::ensure!( + person.name == Some("John Doe".to_string()), + "expected name John Doe, got {:?}", + person.name + ); + anyhow::ensure!( + person.age == Some(30), + "expected age 30, got {:?}", + person.age + ); + assert_compatible_professions(person.profession.as_deref(), "software engineer")?; Ok(()) } @@ -63,12 +72,20 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { .extract_with_usage("Jane Smith is a 45 year old data scientist.") .await?; - assert_eq!(response.data.name, Some("Jane Smith".to_string())); - assert_eq!(response.data.age, Some(45)); - assert_compatible_professions(response.data.profession.as_deref(), "data scientist"); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.output_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!( + response.data.name == Some("Jane Smith".to_string()), + "expected name Jane Smith, got {:?}", + response.data.name + ); + anyhow::ensure!( + response.data.age == Some(45), + "expected age 45, got {:?}", + response.data.age + ); + assert_compatible_professions(response.data.profession.as_deref(), "data scientist")?; + anyhow::ensure!(response.usage.input_tokens > 0, "expected input tokens"); + anyhow::ensure!(response.usage.output_tokens > 0, "expected output tokens"); + anyhow::ensure!(response.usage.total_tokens > 0, "expected total tokens"); Ok(()) } @@ 
-90,12 +107,28 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { ) .await?; - assert_eq!(response.data.street, Some("123 Main St".to_string())); - assert_eq!(response.data.city, Some("Springfield".to_string())); - assert_eq!(response.data.state, Some("IL".to_string())); - assert_eq!(response.data.zip_code, Some("62701".to_string())); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!( + response.data.street == Some("123 Main St".to_string()), + "expected street 123 Main St, got {:?}", + response.data.street + ); + anyhow::ensure!( + response.data.city == Some("Springfield".to_string()), + "expected city Springfield, got {:?}", + response.data.city + ); + anyhow::ensure!( + response.data.state == Some("IL".to_string()), + "expected state IL, got {:?}", + response.data.state + ); + anyhow::ensure!( + response.data.zip_code == Some("62701".to_string()), + "expected zip code 62701, got {:?}", + response.data.zip_code + ); + anyhow::ensure!(response.usage.input_tokens > 0, "expected input tokens"); + anyhow::ensure!(response.usage.total_tokens > 0, "expected total tokens"); Ok(()) } @@ -110,13 +143,29 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { let person = extractor.extract(text).await?; let response = extractor.extract_with_usage(text).await?; - assert_eq!(person.name, Some("Bob Johnson".to_string())); - assert_eq!(response.data.name, Some("Bob Johnson".to_string())); - assert_eq!(person.age, Some(55)); - assert_eq!(response.data.age, Some(55)); - assert_compatible_professions(person.profession.as_deref(), "retired teacher"); - assert_compatible_professions(response.data.profession.as_deref(), "retired teacher"); - assert!(response.usage.total_tokens > 0, "usage should be populated"); + anyhow::ensure!( + person.name == Some("Bob Johnson".to_string()), + "expected extracted name Bob Johnson, got {:?}", + person.name + ); + anyhow::ensure!( + response.data.name == Some("Bob Johnson".to_string()), + "expected usage response name Bob Johnson, got {:?}", + response.data.name + ); + anyhow::ensure!( + person.age == Some(55), + "expected extracted age 55, got {:?}", + person.age + ); + anyhow::ensure!( + response.data.age == Some(55), + "expected usage response age 55, got {:?}", + response.data.age + ); + assert_compatible_professions(person.profession.as_deref(), "retired teacher")?; + assert_compatible_professions(response.data.profession.as_deref(), "retired teacher")?; + anyhow::ensure!(response.usage.total_tokens > 0, "usage should be populated"); Ok(()) } @@ -130,13 +179,19 @@ async fn usage_tracking_works_for_different_schemas() -> Result<()> { let person_response = person_extractor .extract_with_usage("Alice is a 25 year old developer.") .await?; - assert!(person_response.usage.total_tokens > 0); + anyhow::ensure!( + person_response.usage.total_tokens > 0, + "expected person usage tokens" + ); let address_extractor = client.extractor::
(deepseek::DEEPSEEK_CHAT).build(); let address_response = address_extractor .extract_with_usage("456 Oak Avenue, Cambridge, MA 02139") .await?; - assert!(address_response.usage.total_tokens > 0); + anyhow::ensure!( + address_response.usage.total_tokens > 0, + "expected address usage tokens" + ); Ok(()) } diff --git a/rig/rig-core/tests/deepseek/multi_extract.rs b/rig/rig-core/tests/deepseek/multi_extract.rs index 9a5cdeba4..4e0db7ce5 100644 --- a/rig/rig-core/tests/deepseek/multi_extract.rs +++ b/rig/rig-core/tests/deepseek/multi_extract.rs @@ -73,7 +73,11 @@ async fn batch_multi_extract_chain() -> Result<()> { ) .await?; - assert_eq!(responses.len(), 3); + anyhow::ensure!( + responses.len() == 3, + "expected three responses, got {}", + responses.len() + ); for response in responses { assert_nonempty_response(&response); } diff --git a/rig/rig-core/tests/deepseek/permission_control.rs b/rig/rig-core/tests/deepseek/permission_control.rs index 714683d73..b83f9f84e 100644 --- a/rig/rig-core/tests/deepseek/permission_control.rs +++ b/rig/rig-core/tests/deepseek/permission_control.rs @@ -178,8 +178,14 @@ async fn permission_control_prompt_example() -> Result<()> { .await?; let last = last_result.lock().expect("lock last_result").clone(); - assert_eq!(last.as_deref(), Some("hello world")); - assert_eq!(call_count.load(Ordering::SeqCst), 2); + anyhow::ensure!( + last.as_deref() == Some("hello world"), + "expected final tool result hello world, got {last:?}" + ); + anyhow::ensure!( + call_count.load(Ordering::SeqCst) == 2, + "expected two tool hook calls" + ); Ok(()) } @@ -216,7 +222,7 @@ async fn permission_control_streaming_example() -> Result<()> { let final_response = stream_to_stdout(&mut stream).await?; let last = last_result.lock().expect("lock last_result").clone(); assert_nonempty_response(final_response.response()); - assert!( + anyhow::ensure!( final_response .response() .to_ascii_lowercase() @@ -224,8 +230,14 @@ async fn permission_control_streaming_example() -> Result<()> { "expected the streamed final response to mention the file content, got {:?}", final_response.response() ); - assert_eq!(last.as_deref(), Some("hello world")); - assert_eq!(call_count.load(Ordering::SeqCst), 2); + anyhow::ensure!( + last.as_deref() == Some("hello world"), + "expected final tool result hello world, got {last:?}" + ); + anyhow::ensure!( + call_count.load(Ordering::SeqCst) == 2, + "expected two tool hook calls" + ); Ok(()) } diff --git a/rig/rig-core/tests/deepseek/request_hook.rs b/rig/rig-core/tests/deepseek/request_hook.rs index ebc78cfd9..8e1e080d7 100644 --- a/rig/rig-core/tests/deepseek/request_hook.rs +++ b/rig/rig-core/tests/deepseek/request_hook.rs @@ -88,8 +88,14 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .await?; assert_nonempty_response(&response); - assert_eq!(hook.prompt_calls.load(Ordering::SeqCst), 1); - assert_eq!(hook.response_calls.load(Ordering::SeqCst), 1); + anyhow::ensure!( + hook.prompt_calls.load(Ordering::SeqCst) == 1, + "expected one prompt hook call" + ); + anyhow::ensure!( + hook.response_calls.load(Ordering::SeqCst) == 1, + "expected one response hook call" + ); let seen_prompt = hook .seen_prompt @@ -102,15 +108,17 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .map_err(|_| anyhow!("response hook state unavailable"))? 
.clone(); - assert!( + anyhow::ensure!( seen_prompt .as_deref() - .is_some_and(|prompt| prompt.contains("Entertain me!")) + .is_some_and(|prompt| prompt.contains("Entertain me!")), + "expected hook to capture prompt text" ); - assert!( + anyhow::ensure!( seen_response .as_deref() - .is_some_and(|captured| !captured.is_empty()) + .is_some_and(|captured| !captured.is_empty()), + "expected hook to capture response text" ); Ok(()) diff --git a/rig/rig-core/tests/groq/extractor_usage.rs b/rig/rig-core/tests/groq/extractor_usage.rs index c341647a7..e688c8b60 100644 --- a/rig/rig-core/tests/groq/extractor_usage.rs +++ b/rig/rig-core/tests/groq/extractor_usage.rs @@ -1,6 +1,6 @@ //! Integration tests for Groq extractor usage tracking. -use anyhow::Result; +use anyhow::{Result, anyhow}; use rig::client::{CompletionClient, ProviderClient}; use rig::extractor::ExtractionResponse; use rig::message::Message; @@ -29,17 +29,18 @@ struct Address { zip_code: Option, } -fn assert_compatible_professions(left: Option<&str>, right: &str) { +fn assert_compatible_professions(left: Option<&str>, right: &str) -> Result<()> { let left = left - .expect("profession should be present") + .ok_or_else(|| anyhow!("profession should be present"))? .trim() .to_ascii_lowercase(); let right = right.trim().to_ascii_lowercase(); - assert!( + anyhow::ensure!( left == right || left.contains(&right) || right.contains(&left), "expected compatible professions, got {left:?} and {right:?}" ); + Ok(()) } #[tokio::test] @@ -54,9 +55,9 @@ async fn extract_backward_compatibility() -> Result<()> { .extract("John Doe is a 30 year old software engineer.") .await?; - assert_eq!(person.name, Some("John Doe".to_string())); - assert_eq!(person.age, Some(30)); - assert_compatible_professions(person.profession.as_deref(), "software engineer"); + anyhow::ensure!(person.name.as_deref() == Some("John Doe")); + anyhow::ensure!(person.age == Some(30)); + assert_compatible_professions(person.profession.as_deref(), "software engineer")?; Ok(()) } @@ -73,12 +74,12 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { .extract_with_usage("Jane Smith is a 45 year old data scientist.") .await?; - assert_eq!(response.data.name, Some("Jane Smith".to_string())); - assert_eq!(response.data.age, Some(45)); - assert_compatible_professions(response.data.profession.as_deref(), "data scientist"); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.output_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.name.as_deref() == Some("Jane Smith")); + anyhow::ensure!(response.data.age == Some(45)); + assert_compatible_professions(response.data.profession.as_deref(), "data scientist")?; + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.output_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -102,12 +103,12 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { ) .await?; - assert_eq!(response.data.street, Some("123 Main St".to_string())); - assert_eq!(response.data.city, Some("Springfield".to_string())); - assert_eq!(response.data.state, Some("IL".to_string())); - assert_eq!(response.data.zip_code, Some("62701".to_string())); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.street.as_deref() == Some("123 Main St")); + anyhow::ensure!(response.data.city.as_deref() == Some("Springfield")); + anyhow::ensure!(response.data.state.as_deref() == Some("IL")); 
+ anyhow::ensure!(response.data.zip_code.as_deref() == Some("62701")); + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -124,13 +125,13 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { let person = extractor.extract(text).await?; let response = extractor.extract_with_usage(text).await?; - assert_eq!(person.name, Some("Bob Johnson".to_string())); - assert_eq!(response.data.name, Some("Bob Johnson".to_string())); - assert_eq!(person.age, Some(55)); - assert_eq!(response.data.age, Some(55)); - assert_compatible_professions(person.profession.as_deref(), "retired teacher"); - assert_compatible_professions(response.data.profession.as_deref(), "retired teacher"); - assert!(response.usage.total_tokens > 0, "usage should be populated"); + anyhow::ensure!(person.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(response.data.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(person.age == Some(55)); + anyhow::ensure!(response.data.age == Some(55)); + assert_compatible_professions(person.profession.as_deref(), "retired teacher")?; + assert_compatible_professions(response.data.profession.as_deref(), "retired teacher")?; + anyhow::ensure!(response.usage.total_tokens > 0, "usage should be populated"); Ok(()) } @@ -146,7 +147,7 @@ async fn usage_tracking_works_for_different_schemas() -> Result<()> { let person_response = person_extractor .extract_with_usage("Alice is a 25 year old developer.") .await?; - assert!(person_response.usage.total_tokens > 0); + anyhow::ensure!(person_response.usage.total_tokens > 0); let address_extractor = client .extractor::<Address>
(EXTRACTOR_USAGE_TRACKING_MODEL) @@ -154,7 +155,7 @@ async fn usage_tracking_works_for_different_schemas() -> Result<()> { let address_response = address_extractor .extract_with_usage("456 Oak Avenue, Cambridge, MA 02139") .await?; - assert!(address_response.usage.total_tokens > 0); + anyhow::ensure!(address_response.usage.total_tokens > 0); Ok(()) } diff --git a/rig/rig-core/tests/groq/multi_extract.rs b/rig/rig-core/tests/groq/multi_extract.rs index e461e236f..dfabcf792 100644 --- a/rig/rig-core/tests/groq/multi_extract.rs +++ b/rig/rig-core/tests/groq/multi_extract.rs @@ -75,7 +75,11 @@ async fn batch_multi_extract_chain() -> Result<()> { ) .await?; - assert_eq!(responses.len(), 3); + anyhow::ensure!( + responses.len() == 3, + "expected three responses, got {}", + responses.len() + ); for response in responses { assert_nonempty_response(&response); } diff --git a/rig/rig-core/tests/groq/permission_control.rs b/rig/rig-core/tests/groq/permission_control.rs index 77968699c..1187b8169 100644 --- a/rig/rig-core/tests/groq/permission_control.rs +++ b/rig/rig-core/tests/groq/permission_control.rs @@ -182,8 +182,8 @@ async fn permission_control_prompt_example() -> Result<()> { .await?; let last = last_result.lock().expect("lock last_result").clone(); - assert_eq!(last.as_deref(), Some("hello world")); - assert!( + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!( call_count.load(Ordering::SeqCst) >= 2, "expected at least one skipped tool call followed by a successful retry" ); @@ -225,7 +225,7 @@ async fn permission_control_streaming_example() -> Result<()> { let final_response = stream_to_stdout(&mut stream).await?; let last = last_result.lock().expect("lock last_result").clone(); assert_nonempty_response(final_response.response()); - assert!( + anyhow::ensure!( final_response .response() .to_ascii_lowercase() @@ -233,8 +233,8 @@ async fn permission_control_streaming_example() -> Result<()> { "expected the streamed final response to mention the file content, got {:?}", final_response.response() ); - assert_eq!(last.as_deref(), Some("hello world")); - assert!( + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!( call_count.load(Ordering::SeqCst) >= 2, "expected at least one skipped tool call followed by a successful retry" ); diff --git a/rig/rig-core/tests/groq/request_hook.rs b/rig/rig-core/tests/groq/request_hook.rs index c44a9f028..2a9525eaa 100644 --- a/rig/rig-core/tests/groq/request_hook.rs +++ b/rig/rig-core/tests/groq/request_hook.rs @@ -90,8 +90,8 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .await?; assert_nonempty_response(&response); - assert_eq!(hook.prompt_calls.load(Ordering::SeqCst), 1); - assert_eq!(hook.response_calls.load(Ordering::SeqCst), 1); + anyhow::ensure!(hook.prompt_calls.load(Ordering::SeqCst) == 1); + anyhow::ensure!(hook.response_calls.load(Ordering::SeqCst) == 1); let seen_prompt = hook .seen_prompt @@ -104,12 +104,12 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .map_err(|_| anyhow!("response hook state unavailable"))? 
.clone(); - assert!( + anyhow::ensure!( seen_prompt .as_deref() .is_some_and(|prompt| prompt.contains("Entertain me!")) ); - assert!( + anyhow::ensure!( seen_response .as_deref() .is_some_and(|captured| !captured.is_empty()) diff --git a/rig/rig-core/tests/groq/typed_prompt_tools.rs b/rig/rig-core/tests/groq/typed_prompt_tools.rs index 7988ac36a..18910f60a 100644 --- a/rig/rig-core/tests/groq/typed_prompt_tools.rs +++ b/rig/rig-core/tests/groq/typed_prompt_tools.rs @@ -99,7 +99,7 @@ async fn prompt_typed_with_tool_call_roundtrip() -> Result<()> { .prompt_typed("Hello, whats the weather in London?") .await?; - assert!( + anyhow::ensure!( call_count.load(Ordering::SeqCst) >= 1, "expected the weather tool to be executed at least once" ); diff --git a/rig/rig-core/tests/llamacpp/extractor_usage.rs b/rig/rig-core/tests/llamacpp/extractor_usage.rs index 3ab350a6b..64443371f 100644 --- a/rig/rig-core/tests/llamacpp/extractor_usage.rs +++ b/rig/rig-core/tests/llamacpp/extractor_usage.rs @@ -1,6 +1,6 @@ //! Integration tests for llama.cpp extractor usage tracking. -use anyhow::Result; +use anyhow::{Result, anyhow}; use rig::extractor::ExtractionResponse; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -35,20 +35,21 @@ Extract every field explicitly stated in the input text. Do not omit keys when the value is present in the text. Return the exact stated values through the submit tool."; -fn assert_compatible_professions(left: Option<&str>, right: Option<&str>) { +fn assert_compatible_professions(left: Option<&str>, right: Option<&str>) -> Result<()> { let left = left - .expect("profession should be present") + .ok_or_else(|| anyhow!("profession should be present"))? .trim() .to_ascii_lowercase(); let right = right - .expect("profession should be present") + .ok_or_else(|| anyhow!("profession should be present"))? 
.trim() .to_ascii_lowercase(); - assert!( + anyhow::ensure!( left == right || left.contains(&right) || right.contains(&left), "expected compatible professions, got {left:?} and {right:?}" ); + Ok(()) } #[tokio::test] @@ -66,9 +67,9 @@ async fn extract_backward_compatibility() -> Result<()> { .extract("John Doe is a 30 year old software engineer.") .await?; - assert_eq!(person.name, Some("John Doe".to_string())); - assert_eq!(person.age, Some(30)); - assert_eq!(person.profession, Some("software engineer".to_string())); + anyhow::ensure!(person.name.as_deref() == Some("John Doe")); + anyhow::ensure!(person.age == Some(30)); + anyhow::ensure!(person.profession.as_deref() == Some("software engineer")); Ok(()) } @@ -88,12 +89,12 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { .extract_with_usage("Jane Smith is a 45 year old data scientist.") .await?; - assert_eq!(response.data.name, Some("Jane Smith".to_string())); - assert_eq!(response.data.age, Some(45)); - assert_eq!(response.data.profession, Some("data scientist".to_string())); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.output_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.name.as_deref() == Some("Jane Smith")); + anyhow::ensure!(response.data.age == Some(45)); + anyhow::ensure!(response.data.profession.as_deref() == Some("data scientist")); + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.output_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -122,12 +123,12 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { ) .await?; - assert_eq!(response.data.street, Some("123 Main St".to_string())); - assert_eq!(response.data.city, Some("Springfield".to_string())); - assert_eq!(response.data.state, Some("IL".to_string())); - assert_eq!(response.data.zip_code, Some("62701".to_string())); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.street.as_deref() == Some("123 Main St")); + anyhow::ensure!(response.data.city.as_deref() == Some("Springfield")); + anyhow::ensure!(response.data.state.as_deref() == Some("IL")); + anyhow::ensure!(response.data.zip_code.as_deref() == Some("62701")); + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -147,15 +148,15 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { let person = extractor.extract(text).await?; let response = extractor.extract_with_usage(text).await?; - assert_eq!(person.name, Some("Bob Johnson".to_string())); - assert_eq!(response.data.name, Some("Bob Johnson".to_string())); - assert_eq!(person.age, Some(55)); - assert_eq!(response.data.age, Some(55)); + anyhow::ensure!(person.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(response.data.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(person.age == Some(55)); + anyhow::ensure!(response.data.age == Some(55)); assert_compatible_professions( person.profession.as_deref(), response.data.profession.as_deref(), - ); - assert!(response.usage.total_tokens > 0, "usage should be populated"); + )?; + anyhow::ensure!(response.usage.total_tokens > 0, "usage should be populated"); Ok(()) } @@ -175,7 +176,7 @@ async fn usage_tracking_works_for_different_schemas() -> Result<()> { .extract_with_usage("Alice is a 25 year old developer.") .await?; - assert!(person_response.usage.total_tokens > 0); + 
anyhow::ensure!(person_response.usage.total_tokens > 0); let address_extractor = client .extractor::<Address>
(model) @@ -186,7 +187,7 @@ async fn usage_tracking_works_for_different_schemas() -> Result<()> { .extract_with_usage("456 Oak Avenue, Cambridge, MA 02139") .await?; - assert!(address_response.usage.total_tokens > 0); + anyhow::ensure!(address_response.usage.total_tokens > 0); Ok(()) } diff --git a/rig/rig-core/tests/llamacpp/multi_extract.rs b/rig/rig-core/tests/llamacpp/multi_extract.rs index 519824c46..d33bbf6a2 100644 --- a/rig/rig-core/tests/llamacpp/multi_extract.rs +++ b/rig/rig-core/tests/llamacpp/multi_extract.rs @@ -74,7 +74,11 @@ async fn batch_multi_extract_chain() -> Result<()> { ) .await?; - assert_eq!(responses.len(), 3); + anyhow::ensure!( + responses.len() == 3, + "expected three responses, got {}", + responses.len() + ); for response in responses { assert_nonempty_response(&response); } diff --git a/rig/rig-core/tests/llamacpp/permission_control.rs b/rig/rig-core/tests/llamacpp/permission_control.rs index 591d40bbf..9d0524fe1 100644 --- a/rig/rig-core/tests/llamacpp/permission_control.rs +++ b/rig/rig-core/tests/llamacpp/permission_control.rs @@ -176,8 +176,8 @@ async fn permission_control_prompt_example() -> Result<()> { .await?; let last = last_result.lock().expect("lock last_result").clone(); - assert_eq!(last.as_deref(), Some("hello world")); - assert_eq!(call_count.load(Ordering::SeqCst), 2); + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!(call_count.load(Ordering::SeqCst) == 2); Ok(()) } @@ -212,7 +212,7 @@ async fn permission_control_streaming_example() -> Result<()> { let final_response = stream_to_stdout(&mut stream).await?; let last = last_result.lock().expect("lock last_result").clone(); assert_nonempty_response(final_response.response()); - assert!( + anyhow::ensure!( final_response .response() .to_ascii_lowercase() @@ -220,8 +220,8 @@ async fn permission_control_streaming_example() -> Result<()> { "expected the streamed final response to mention the file content, got {:?}", final_response.response() ); - assert_eq!(last.as_deref(), Some("hello world")); - assert_eq!(call_count.load(Ordering::SeqCst), 2); + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!(call_count.load(Ordering::SeqCst) == 2); Ok(()) } diff --git a/rig/rig-core/tests/llamacpp/request_hook.rs b/rig/rig-core/tests/llamacpp/request_hook.rs index 07c845739..e9626c2c3 100644 --- a/rig/rig-core/tests/llamacpp/request_hook.rs +++ b/rig/rig-core/tests/llamacpp/request_hook.rs @@ -88,8 +88,8 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .await?; assert_nonempty_response(&response); - assert_eq!(hook.prompt_calls.load(Ordering::SeqCst), 1); - assert_eq!(hook.response_calls.load(Ordering::SeqCst), 1); + anyhow::ensure!(hook.prompt_calls.load(Ordering::SeqCst) == 1); + anyhow::ensure!(hook.response_calls.load(Ordering::SeqCst) == 1); let seen_prompt = hook .seen_prompt @@ -102,12 +102,12 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .map_err(|_| anyhow!("response hook state unavailable"))? 
.clone(); - assert!( + anyhow::ensure!( seen_prompt .as_deref() .is_some_and(|prompt| prompt.contains("Entertain me!")) ); - assert!( + anyhow::ensure!( seen_response .as_deref() .is_some_and(|captured| !captured.is_empty()) diff --git a/rig/rig-core/tests/llamacpp/typed_prompt_tools.rs b/rig/rig-core/tests/llamacpp/typed_prompt_tools.rs index aeec0534a..1597f72e8 100644 --- a/rig/rig-core/tests/llamacpp/typed_prompt_tools.rs +++ b/rig/rig-core/tests/llamacpp/typed_prompt_tools.rs @@ -201,7 +201,7 @@ async fn prompt_typed_with_tool_call_verbatim_roundtrip() -> Result<()> { let response = result?; println!("agent response: {response:#?}"); - assert!( + anyhow::ensure!( call_count.load(Ordering::SeqCst) >= 1, "expected the weather tool to be executed at least once" ); diff --git a/rig/rig-core/tests/llamafile/extractor_usage.rs b/rig/rig-core/tests/llamafile/extractor_usage.rs index 34cd5904a..057a0b4b0 100644 --- a/rig/rig-core/tests/llamafile/extractor_usage.rs +++ b/rig/rig-core/tests/llamafile/extractor_usage.rs @@ -24,17 +24,18 @@ struct Address { zip_code: Option, } -fn assert_compatible_professions(left: Option<&str>, right: &str) { +fn assert_compatible_professions(left: Option<&str>, right: &str) -> Result<()> { let left = left - .expect("profession should be present") + .ok_or_else(|| anyhow::anyhow!("profession should be present"))? .trim() .to_ascii_lowercase(); let right = right.trim().to_ascii_lowercase(); - assert!( + anyhow::ensure!( left == right || left.contains(&right) || right.contains(&left), "expected compatible professions, got {left:?} and {right:?}" ); + Ok(()) } #[tokio::test] @@ -52,9 +53,9 @@ async fn extract_backward_compatibility() -> Result<()> { .extract("John Doe is a 30 year old software engineer.") .await?; - assert_eq!(person.name, Some("John Doe".to_string())); - assert_eq!(person.age, Some(30)); - assert_compatible_professions(person.profession.as_deref(), "software engineer"); + anyhow::ensure!(person.name.as_deref() == Some("John Doe")); + anyhow::ensure!(person.age == Some(30)); + assert_compatible_professions(person.profession.as_deref(), "software engineer")?; Ok(()) } @@ -74,12 +75,12 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { .extract_with_usage("Jane Smith is a 45 year old data scientist.") .await?; - assert_eq!(response.data.name, Some("Jane Smith".to_string())); - assert_eq!(response.data.age, Some(45)); - assert_compatible_professions(response.data.profession.as_deref(), "data scientist"); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.output_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.name.as_deref() == Some("Jane Smith")); + anyhow::ensure!(response.data.age == Some(45)); + assert_compatible_professions(response.data.profession.as_deref(), "data scientist")?; + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.output_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -106,12 +107,12 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { ) .await?; - assert_eq!(response.data.street, Some("123 Main St".to_string())); - assert_eq!(response.data.city, Some("Springfield".to_string())); - assert_eq!(response.data.state, Some("IL".to_string())); - assert_eq!(response.data.zip_code, Some("62701".to_string())); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.street.as_deref() == Some("123 Main St")); + 
anyhow::ensure!(response.data.city.as_deref() == Some("Springfield")); + anyhow::ensure!(response.data.state.as_deref() == Some("IL")); + anyhow::ensure!(response.data.zip_code.as_deref() == Some("62701")); + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -131,13 +132,13 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { let person = extractor.extract(text).await?; let response = extractor.extract_with_usage(text).await?; - assert_eq!(person.name, Some("Bob Johnson".to_string())); - assert_eq!(response.data.name, Some("Bob Johnson".to_string())); - assert_eq!(person.age, Some(55)); - assert_eq!(response.data.age, Some(55)); - assert_compatible_professions(person.profession.as_deref(), "retired teacher"); - assert_compatible_professions(response.data.profession.as_deref(), "retired teacher"); - assert!(response.usage.total_tokens > 0, "usage should be populated"); + anyhow::ensure!(person.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(response.data.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(person.age == Some(55)); + anyhow::ensure!(response.data.age == Some(55)); + assert_compatible_professions(person.profession.as_deref(), "retired teacher")?; + assert_compatible_professions(response.data.profession.as_deref(), "retired teacher")?; + anyhow::ensure!(response.usage.total_tokens > 0, "usage should be populated"); Ok(()) } @@ -156,13 +157,13 @@ async fn usage_tracking_works_for_different_schemas() -> Result<()> { let person_response = person_extractor .extract_with_usage("Alice is a 25 year old developer.") .await?; - assert!(person_response.usage.total_tokens > 0); + anyhow::ensure!(person_response.usage.total_tokens > 0); let address_extractor = client.extractor::<Address>
(model).build(); let address_response = address_extractor .extract_with_usage("456 Oak Avenue, Cambridge, MA 02139") .await?; - assert!(address_response.usage.total_tokens > 0); + anyhow::ensure!(address_response.usage.total_tokens > 0); Ok(()) } diff --git a/rig/rig-core/tests/llamafile/multi_extract.rs b/rig/rig-core/tests/llamafile/multi_extract.rs index 8af3047f0..e6f774afd 100644 --- a/rig/rig-core/tests/llamafile/multi_extract.rs +++ b/rig/rig-core/tests/llamafile/multi_extract.rs @@ -82,7 +82,7 @@ async fn batch_multi_extract_chain() -> Result<()> { ) .await?; - assert_eq!(responses.len(), 3); + anyhow::ensure!(responses.len() == 3); for response in responses { assert_nonempty_response(&response); } diff --git a/rig/rig-core/tests/llamafile/permission_control.rs b/rig/rig-core/tests/llamafile/permission_control.rs index 0290bb3cc..0157d45c3 100644 --- a/rig/rig-core/tests/llamafile/permission_control.rs +++ b/rig/rig-core/tests/llamafile/permission_control.rs @@ -232,8 +232,8 @@ async fn permission_control_prompt_example() -> Result<()> { ); return Ok(()); } - assert_eq!(last.as_deref(), Some("hello world")); - assert!( + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!( call_count.load(Ordering::SeqCst) >= 2, "expected at least one skipped tool call followed by a successful retry" ); @@ -289,7 +289,7 @@ async fn permission_control_streaming_example() -> Result<()> { return Ok(()); } assert_nonempty_response(final_response.response()); - assert!( + anyhow::ensure!( final_response .response() .to_ascii_lowercase() @@ -297,8 +297,8 @@ async fn permission_control_streaming_example() -> Result<()> { "expected the streamed final response to mention the file content, got {:?}", final_response.response() ); - assert_eq!(last.as_deref(), Some("hello world")); - assert!( + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!( call_count.load(Ordering::SeqCst) >= 2, "expected at least one skipped tool call followed by a successful retry" ); diff --git a/rig/rig-core/tests/llamafile/request_hook.rs b/rig/rig-core/tests/llamafile/request_hook.rs index 60cb92b4d..6a8f9706d 100644 --- a/rig/rig-core/tests/llamafile/request_hook.rs +++ b/rig/rig-core/tests/llamafile/request_hook.rs @@ -92,8 +92,8 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .await?; assert_nonempty_response(&response); - assert_eq!(hook.prompt_calls.load(Ordering::SeqCst), 1); - assert_eq!(hook.response_calls.load(Ordering::SeqCst), 1); + anyhow::ensure!(hook.prompt_calls.load(Ordering::SeqCst) == 1); + anyhow::ensure!(hook.response_calls.load(Ordering::SeqCst) == 1); let seen_prompt = hook .seen_prompt @@ -106,12 +106,12 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .map_err(|_| anyhow!("response hook state unavailable"))? 
.clone(); - assert!( + anyhow::ensure!( seen_prompt .as_deref() .is_some_and(|prompt| prompt.contains("Entertain me!")) ); - assert!( + anyhow::ensure!( seen_response .as_deref() .is_some_and(|captured| !captured.is_empty()) diff --git a/rig/rig-core/tests/llamafile/typed_prompt_tools.rs b/rig/rig-core/tests/llamafile/typed_prompt_tools.rs index e6922462a..0d680af42 100644 --- a/rig/rig-core/tests/llamafile/typed_prompt_tools.rs +++ b/rig/rig-core/tests/llamafile/typed_prompt_tools.rs @@ -96,7 +96,7 @@ async fn prompt_typed_with_tool_call_roundtrip() -> Result<()> { .prompt_typed("Hello, whats the weather in London?") .await?; - assert!( + anyhow::ensure!( call_count.load(Ordering::SeqCst) >= 1, "expected the weather tool to be executed at least once" ); diff --git a/rig/rig-core/tests/mistral/extractor_usage.rs b/rig/rig-core/tests/mistral/extractor_usage.rs index e083778a3..e5e02c4e2 100644 --- a/rig/rig-core/tests/mistral/extractor_usage.rs +++ b/rig/rig-core/tests/mistral/extractor_usage.rs @@ -25,17 +25,18 @@ struct Address { zip_code: Option, } -fn assert_compatible_professions(left: Option<&str>, right: &str) { +fn assert_compatible_professions(left: Option<&str>, right: &str) -> Result<()> { let left = left - .expect("profession should be present") + .ok_or_else(|| anyhow::anyhow!("profession should be present"))? .trim() .to_ascii_lowercase(); let right = right.trim().to_ascii_lowercase(); - assert!( + anyhow::ensure!( left == right || left.contains(&right) || right.contains(&left), "expected compatible professions, got {left:?} and {right:?}" ); + Ok(()) } #[tokio::test] @@ -48,9 +49,9 @@ async fn extract_backward_compatibility() -> Result<()> { .extract("John Doe is a 30 year old software engineer.") .await?; - assert_eq!(person.name, Some("John Doe".to_string())); - assert_eq!(person.age, Some(30)); - assert_compatible_professions(person.profession.as_deref(), "software engineer"); + anyhow::ensure!(person.name.as_deref() == Some("John Doe")); + anyhow::ensure!(person.age == Some(30)); + assert_compatible_professions(person.profession.as_deref(), "software engineer")?; Ok(()) } @@ -65,12 +66,12 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { .extract_with_usage("Jane Smith is a 45 year old data scientist.") .await?; - assert_eq!(response.data.name, Some("Jane Smith".to_string())); - assert_eq!(response.data.age, Some(45)); - assert_compatible_professions(response.data.profession.as_deref(), "data scientist"); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.output_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.name.as_deref() == Some("Jane Smith")); + anyhow::ensure!(response.data.age == Some(45)); + assert_compatible_professions(response.data.profession.as_deref(), "data scientist")?; + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.output_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -92,12 +93,12 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { ) .await?; - assert_eq!(response.data.street, Some("123 Main St".to_string())); - assert_eq!(response.data.city, Some("Springfield".to_string())); - assert_eq!(response.data.state, Some("IL".to_string())); - assert_eq!(response.data.zip_code, Some("62701".to_string())); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.street.as_deref() == Some("123 Main St")); + 
anyhow::ensure!(response.data.city.as_deref() == Some("Springfield")); + anyhow::ensure!(response.data.state.as_deref() == Some("IL")); + anyhow::ensure!(response.data.zip_code.as_deref() == Some("62701")); + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -112,13 +113,13 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { let person = extractor.extract(text).await?; let response = extractor.extract_with_usage(text).await?; - assert_eq!(person.name, Some("Bob Johnson".to_string())); - assert_eq!(response.data.name, Some("Bob Johnson".to_string())); - assert_eq!(person.age, Some(55)); - assert_eq!(response.data.age, Some(55)); - assert_compatible_professions(person.profession.as_deref(), "retired teacher"); - assert_compatible_professions(response.data.profession.as_deref(), "retired teacher"); - assert!(response.usage.total_tokens > 0, "usage should be populated"); + anyhow::ensure!(person.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(response.data.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(person.age == Some(55)); + anyhow::ensure!(response.data.age == Some(55)); + assert_compatible_professions(person.profession.as_deref(), "retired teacher")?; + assert_compatible_professions(response.data.profession.as_deref(), "retired teacher")?; + anyhow::ensure!(response.usage.total_tokens > 0, "usage should be populated"); Ok(()) } @@ -132,13 +133,13 @@ async fn usage_tracking_works_for_different_schemas() -> Result<()> { let person_response = person_extractor .extract_with_usage("Alice is a 25 year old developer.") .await?; - assert!(person_response.usage.total_tokens > 0); + anyhow::ensure!(person_response.usage.total_tokens > 0); let address_extractor = client.extractor::<Address>
(DEFAULT_MODEL).build(); let address_response = address_extractor .extract_with_usage("456 Oak Avenue, Cambridge, MA 02139") .await?; - assert!(address_response.usage.total_tokens > 0); + anyhow::ensure!(address_response.usage.total_tokens > 0); Ok(()) } diff --git a/rig/rig-core/tests/mistral/multi_extract.rs b/rig/rig-core/tests/mistral/multi_extract.rs index 726837eab..7bc2bea5f 100644 --- a/rig/rig-core/tests/mistral/multi_extract.rs +++ b/rig/rig-core/tests/mistral/multi_extract.rs @@ -118,7 +118,7 @@ async fn batch_multi_extract_chain() -> Result<()> { ) .await?; - assert_eq!(responses.len(), 3); + anyhow::ensure!(responses.len() == 3); assert_contains_any( &responses[0].names, diff --git a/rig/rig-core/tests/mistral/permission_control.rs b/rig/rig-core/tests/mistral/permission_control.rs index a62f8b36d..bd03dbc83 100644 --- a/rig/rig-core/tests/mistral/permission_control.rs +++ b/rig/rig-core/tests/mistral/permission_control.rs @@ -180,8 +180,8 @@ async fn permission_control_prompt_example() -> Result<()> { .await?; let last = last_result.lock().expect("lock last_result").clone(); - assert_eq!(last.as_deref(), Some("hello world")); - assert_eq!(call_count.load(Ordering::SeqCst), 2); + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!(call_count.load(Ordering::SeqCst) == 2); Ok(()) } @@ -218,7 +218,7 @@ async fn permission_control_streaming_example() -> Result<()> { let final_response = stream_to_stdout(&mut stream).await?; let last = last_result.lock().expect("lock last_result").clone(); assert_nonempty_response(final_response.response()); - assert!( + anyhow::ensure!( final_response .response() .to_ascii_lowercase() @@ -226,8 +226,8 @@ async fn permission_control_streaming_example() -> Result<()> { "expected the streamed final response to mention the file content, got {:?}", final_response.response() ); - assert_eq!(last.as_deref(), Some("hello world")); - assert_eq!(call_count.load(Ordering::SeqCst), 2); + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!(call_count.load(Ordering::SeqCst) == 2); Ok(()) } diff --git a/rig/rig-core/tests/mistral/request_hook.rs b/rig/rig-core/tests/mistral/request_hook.rs index b829892f7..5cbaa06fe 100644 --- a/rig/rig-core/tests/mistral/request_hook.rs +++ b/rig/rig-core/tests/mistral/request_hook.rs @@ -90,8 +90,8 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .await?; assert_nonempty_response(&response); - assert_eq!(hook.prompt_calls.load(Ordering::SeqCst), 1); - assert_eq!(hook.response_calls.load(Ordering::SeqCst), 1); + anyhow::ensure!(hook.prompt_calls.load(Ordering::SeqCst) == 1); + anyhow::ensure!(hook.response_calls.load(Ordering::SeqCst) == 1); let seen_prompt = hook .seen_prompt @@ -104,12 +104,12 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .map_err(|_| anyhow!("response hook state unavailable"))? 
.clone(); - assert!( + anyhow::ensure!( seen_prompt .as_deref() .is_some_and(|prompt| prompt.contains("Entertain me!")) ); - assert!( + anyhow::ensure!( seen_response .as_deref() .is_some_and(|captured| !captured.is_empty()) diff --git a/rig/rig-core/tests/mistral/typed_prompt_tools.rs b/rig/rig-core/tests/mistral/typed_prompt_tools.rs index 50f2ed9ce..dec5444d2 100644 --- a/rig/rig-core/tests/mistral/typed_prompt_tools.rs +++ b/rig/rig-core/tests/mistral/typed_prompt_tools.rs @@ -95,7 +95,7 @@ async fn prompt_typed_with_tool_call_roundtrip() -> Result<()> { .prompt_typed("Hello, whats the weather in London?") .await?; - assert!( + anyhow::ensure!( call_count.load(Ordering::SeqCst) >= 1, "expected the weather tool to be executed at least once" ); diff --git a/rig/rig-core/tests/openai/extractor_usage.rs b/rig/rig-core/tests/openai/extractor_usage.rs index e96ab4f24..8a968b3af 100644 --- a/rig/rig-core/tests/openai/extractor_usage.rs +++ b/rig/rig-core/tests/openai/extractor_usage.rs @@ -28,20 +28,21 @@ struct Address { zip_code: Option, } -fn assert_compatible_professions(left: Option<&str>, right: Option<&str>) { +fn assert_compatible_professions(left: Option<&str>, right: Option<&str>) -> Result<()> { let left = left - .expect("profession should be present") + .ok_or_else(|| anyhow::anyhow!("profession should be present"))? .trim() .to_ascii_lowercase(); let right = right - .expect("profession should be present") + .ok_or_else(|| anyhow::anyhow!("profession should be present"))? .trim() .to_ascii_lowercase(); - assert!( + anyhow::ensure!( left == right || left.contains(&right) || right.contains(&left), "expected compatible professions, got {left:?} and {right:?}" ); + Ok(()) } /// Test backward compatibility: the original `extract()` method should still work @@ -58,9 +59,9 @@ async fn extract_backward_compatibility() -> Result<()> { .extract("John Doe is a 30 year old software engineer.") .await?; - assert_eq!(person.name, Some("John Doe".to_string())); - assert_eq!(person.age, Some(30)); - assert_eq!(person.profession, Some("software engineer".to_string())); + anyhow::ensure!(person.name.as_deref() == Some("John Doe")); + anyhow::ensure!(person.age == Some(30)); + anyhow::ensure!(person.profession.as_deref() == Some("software engineer")); Ok(()) } @@ -79,14 +80,14 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { .await?; // Verify extracted data - assert_eq!(response.data.name, Some("Jane Smith".to_string())); - assert_eq!(response.data.age, Some(45)); - assert_eq!(response.data.profession, Some("data scientist".to_string())); + anyhow::ensure!(response.data.name.as_deref() == Some("Jane Smith")); + anyhow::ensure!(response.data.age == Some(45)); + anyhow::ensure!(response.data.profession.as_deref() == Some("data scientist")); // Verify usage is non-zero (we made at least one API call) - assert!(response.usage.input_tokens > 0); - assert!(response.usage.output_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.output_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -114,14 +115,14 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { .await?; // Verify extracted data - assert_eq!(response.data.street, Some("123 Main St".to_string())); - assert_eq!(response.data.city, Some("Springfield".to_string())); - assert_eq!(response.data.state, Some("IL".to_string())); - assert_eq!(response.data.zip_code, Some("62701".to_string())); + 
anyhow::ensure!(response.data.street.as_deref() == Some("123 Main St")); + anyhow::ensure!(response.data.city.as_deref() == Some("Springfield")); + anyhow::ensure!(response.data.state.as_deref() == Some("IL")); + anyhow::ensure!(response.data.zip_code.as_deref() == Some("62701")); // Verify usage is non-zero - assert!(response.usage.input_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -144,15 +145,15 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { // Extract with usage let response = extractor.extract_with_usage(text).await?; - assert_eq!(person.name, Some("Bob Johnson".to_string())); - assert_eq!(response.data.name, Some("Bob Johnson".to_string())); - assert_eq!(person.age, Some(55)); - assert_eq!(response.data.age, Some(55)); + anyhow::ensure!(person.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(response.data.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(person.age == Some(55)); + anyhow::ensure!(response.data.age == Some(55)); assert_compatible_professions( person.profession.as_deref(), response.data.profession.as_deref(), - ); - assert!(response.usage.total_tokens > 0, "usage should be populated"); + )?; + anyhow::ensure!(response.usage.total_tokens > 0, "usage should be populated"); Ok(()) } @@ -172,7 +173,7 @@ async fn usage_tracking_works_for_different_schemas() -> Result<()> { .extract_with_usage("Alice is a 25 year old developer.") .await?; - assert!(person_response.usage.total_tokens > 0); + anyhow::ensure!(person_response.usage.total_tokens > 0); // Test with more complex schema let address_extractor = client @@ -183,7 +184,7 @@ async fn usage_tracking_works_for_different_schemas() -> Result<()> { .extract_with_usage("456 Oak Avenue, Cambridge, MA 02139") .await?; - assert!(address_response.usage.total_tokens > 0); + anyhow::ensure!(address_response.usage.total_tokens > 0); Ok(()) } diff --git a/rig/rig-core/tests/openai/multi_extract.rs b/rig/rig-core/tests/openai/multi_extract.rs index 6796159db..3a86d873c 100644 --- a/rig/rig-core/tests/openai/multi_extract.rs +++ b/rig/rig-core/tests/openai/multi_extract.rs @@ -73,7 +73,7 @@ async fn batch_multi_extract_chain() -> Result<()> { ) .await?; - assert_eq!(responses.len(), 3); + anyhow::ensure!(responses.len() == 3); for response in responses { assert_nonempty_response(&response); } diff --git a/rig/rig-core/tests/openai/permission_control.rs b/rig/rig-core/tests/openai/permission_control.rs index 8b2f1430e..f497f65b6 100644 --- a/rig/rig-core/tests/openai/permission_control.rs +++ b/rig/rig-core/tests/openai/permission_control.rs @@ -176,8 +176,8 @@ async fn permission_control_prompt_example() -> Result<()> { .await?; let last = last_result.lock().expect("lock last_result").clone(); - assert_eq!(last.as_deref(), Some("hello world")); - assert_eq!(call_count.load(Ordering::SeqCst), 2); + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!(call_count.load(Ordering::SeqCst) == 2); Ok(()) } @@ -213,7 +213,7 @@ async fn permission_control_streaming_example() -> Result<()> { let final_response = stream_to_stdout(&mut stream).await?; let last = last_result.lock().expect("lock last_result").clone(); assert_nonempty_response(final_response.response()); - assert!( + anyhow::ensure!( final_response .response() .to_ascii_lowercase() @@ -221,8 +221,8 @@ async fn permission_control_streaming_example() -> Result<()> { "expected the streamed 
final response to mention the file content, got {:?}", final_response.response() ); - assert_eq!(last.as_deref(), Some("hello world")); - assert_eq!(call_count.load(Ordering::SeqCst), 2); + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!(call_count.load(Ordering::SeqCst) == 2); Ok(()) } diff --git a/rig/rig-core/tests/openai/request_hook.rs b/rig/rig-core/tests/openai/request_hook.rs index e7ffd6c7c..b8834cd37 100644 --- a/rig/rig-core/tests/openai/request_hook.rs +++ b/rig/rig-core/tests/openai/request_hook.rs @@ -88,8 +88,8 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .await?; assert_nonempty_response(&response); - assert_eq!(hook.prompt_calls.load(Ordering::SeqCst), 1); - assert_eq!(hook.response_calls.load(Ordering::SeqCst), 1); + anyhow::ensure!(hook.prompt_calls.load(Ordering::SeqCst) == 1); + anyhow::ensure!(hook.response_calls.load(Ordering::SeqCst) == 1); let seen_prompt = hook .seen_prompt @@ -102,12 +102,12 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .map_err(|_| anyhow!("response hook state unavailable"))? .clone(); - assert!( + anyhow::ensure!( seen_prompt .as_deref() .is_some_and(|prompt| prompt.contains("Entertain me!")) ); - assert!( + anyhow::ensure!( seen_response .as_deref() .is_some_and(|captured| !captured.is_empty()) diff --git a/rig/rig-core/tests/openai/typed_prompt_tools.rs b/rig/rig-core/tests/openai/typed_prompt_tools.rs index 64d46d38e..794f4bb94 100644 --- a/rig/rig-core/tests/openai/typed_prompt_tools.rs +++ b/rig/rig-core/tests/openai/typed_prompt_tools.rs @@ -102,7 +102,7 @@ async fn prompt_typed_with_tool_call_roundtrip() -> Result<()> { let response: WeatherResponse = result?; println!("agent response: {response:#?}"); - assert!( + anyhow::ensure!( call_count.load(Ordering::SeqCst) >= 1, "expected the weather tool to be executed at least once" ); diff --git a/rig/rig-core/tests/openai/websocket.rs b/rig/rig-core/tests/openai/websocket.rs index ac6f263f2..4dfb3a08f 100644 --- a/rig/rig-core/tests/openai/websocket.rs +++ b/rig/rig-core/tests/openai/websocket.rs @@ -34,7 +34,7 @@ async fn websocket_session_roundtrip() -> Result<()> { .preamble("Be precise and concise.".to_string()) .build(); let warmup_id = session.warmup(warmup_request).await?; - assert!(!warmup_id.is_empty(), "warmup should return a response id"); + anyhow::ensure!(!warmup_id.is_empty(), "warmup should return a response id"); let request = model .completion_request("Explain the benefit of websocket mode in one sentence.") diff --git a/rig/rig-core/tests/openrouter/extractor_usage.rs b/rig/rig-core/tests/openrouter/extractor_usage.rs index 2bd3bdd1e..ec469a2bf 100644 --- a/rig/rig-core/tests/openrouter/extractor_usage.rs +++ b/rig/rig-core/tests/openrouter/extractor_usage.rs @@ -25,17 +25,18 @@ struct Address { zip_code: Option, } -fn assert_compatible_professions(left: Option<&str>, right: &str) { +fn assert_compatible_professions(left: Option<&str>, right: &str) -> Result<()> { let left = left - .expect("profession should be present") + .ok_or_else(|| anyhow::anyhow!("profession should be present"))? 
.trim() .to_ascii_lowercase(); let right = right.trim().to_ascii_lowercase(); - assert!( + anyhow::ensure!( left == right || left.contains(&right) || right.contains(&left), "expected compatible professions, got {left:?} and {right:?}" ); + Ok(()) } #[tokio::test] @@ -48,9 +49,9 @@ async fn extract_backward_compatibility() -> Result<()> { .extract("John Doe is a 30 year old software engineer.") .await?; - assert_eq!(person.name, Some("John Doe".to_string())); - assert_eq!(person.age, Some(30)); - assert_compatible_professions(person.profession.as_deref(), "software engineer"); + anyhow::ensure!(person.name.as_deref() == Some("John Doe")); + anyhow::ensure!(person.age == Some(30)); + assert_compatible_professions(person.profession.as_deref(), "software engineer")?; Ok(()) } @@ -65,12 +66,12 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { .extract_with_usage("Jane Smith is a 45 year old data scientist.") .await?; - assert_eq!(response.data.name, Some("Jane Smith".to_string())); - assert_eq!(response.data.age, Some(45)); - assert_compatible_professions(response.data.profession.as_deref(), "data scientist"); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.output_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.name.as_deref() == Some("Jane Smith")); + anyhow::ensure!(response.data.age == Some(45)); + assert_compatible_professions(response.data.profession.as_deref(), "data scientist")?; + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.output_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -92,12 +93,12 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { ) .await?; - assert_eq!(response.data.street, Some("123 Main St".to_string())); - assert_eq!(response.data.city, Some("Springfield".to_string())); - assert_eq!(response.data.state, Some("IL".to_string())); - assert_eq!(response.data.zip_code, Some("62701".to_string())); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.street.as_deref() == Some("123 Main St")); + anyhow::ensure!(response.data.city.as_deref() == Some("Springfield")); + anyhow::ensure!(response.data.state.as_deref() == Some("IL")); + anyhow::ensure!(response.data.zip_code.as_deref() == Some("62701")); + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -112,13 +113,13 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { let person = extractor.extract(text).await?; let response = extractor.extract_with_usage(text).await?; - assert_eq!(person.name, Some("Bob Johnson".to_string())); - assert_eq!(response.data.name, Some("Bob Johnson".to_string())); - assert_eq!(person.age, Some(55)); - assert_eq!(response.data.age, Some(55)); - assert_compatible_professions(person.profession.as_deref(), "retired teacher"); - assert_compatible_professions(response.data.profession.as_deref(), "retired teacher"); - assert!(response.usage.total_tokens > 0, "usage should be populated"); + anyhow::ensure!(person.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(response.data.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(person.age == Some(55)); + anyhow::ensure!(response.data.age == Some(55)); + assert_compatible_professions(person.profession.as_deref(), "retired teacher")?; + assert_compatible_professions(response.data.profession.as_deref(), "retired 
teacher")?; + anyhow::ensure!(response.usage.total_tokens > 0, "usage should be populated"); Ok(()) } @@ -132,13 +133,13 @@ async fn usage_tracking_works_for_different_schemas() -> Result<()> { let person_response = person_extractor .extract_with_usage("Alice is a 25 year old developer.") .await?; - assert!(person_response.usage.total_tokens > 0); + anyhow::ensure!(person_response.usage.total_tokens > 0); let address_extractor = client.extractor::
(DEFAULT_MODEL).build(); let address_response = address_extractor .extract_with_usage("456 Oak Avenue, Cambridge, MA 02139") .await?; - assert!(address_response.usage.total_tokens > 0); + anyhow::ensure!(address_response.usage.total_tokens > 0); Ok(()) } diff --git a/rig/rig-core/tests/openrouter/multi_extract.rs b/rig/rig-core/tests/openrouter/multi_extract.rs index 671247b8e..a7df40690 100644 --- a/rig/rig-core/tests/openrouter/multi_extract.rs +++ b/rig/rig-core/tests/openrouter/multi_extract.rs @@ -75,7 +75,7 @@ async fn batch_multi_extract_chain() -> Result<()> { ) .await?; - assert_eq!(responses.len(), 3); + anyhow::ensure!(responses.len() == 3); for response in responses { assert_nonempty_response(&response); } diff --git a/rig/rig-core/tests/openrouter/permission_control.rs b/rig/rig-core/tests/openrouter/permission_control.rs index d47ad91ca..55eaad5c5 100644 --- a/rig/rig-core/tests/openrouter/permission_control.rs +++ b/rig/rig-core/tests/openrouter/permission_control.rs @@ -180,8 +180,8 @@ async fn permission_control_prompt_example() -> Result<()> { .await?; let last = last_result.lock().expect("lock last_result").clone(); - assert_eq!(last.as_deref(), Some("hello world")); - assert_eq!(call_count.load(Ordering::SeqCst), 2); + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!(call_count.load(Ordering::SeqCst) == 2); Ok(()) } @@ -218,7 +218,7 @@ async fn permission_control_streaming_example() -> Result<()> { let final_response = stream_to_stdout(&mut stream).await?; let last = last_result.lock().expect("lock last_result").clone(); assert_nonempty_response(final_response.response()); - assert!( + anyhow::ensure!( final_response .response() .to_ascii_lowercase() @@ -226,8 +226,8 @@ async fn permission_control_streaming_example() -> Result<()> { "expected the streamed final response to mention the file content, got {:?}", final_response.response() ); - assert_eq!(last.as_deref(), Some("hello world")); - assert_eq!(call_count.load(Ordering::SeqCst), 2); + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!(call_count.load(Ordering::SeqCst) == 2); Ok(()) } diff --git a/rig/rig-core/tests/openrouter/request_hook.rs b/rig/rig-core/tests/openrouter/request_hook.rs index d33df35ee..4a03a3f7b 100644 --- a/rig/rig-core/tests/openrouter/request_hook.rs +++ b/rig/rig-core/tests/openrouter/request_hook.rs @@ -90,8 +90,8 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .await?; assert_nonempty_response(&response); - assert_eq!(hook.prompt_calls.load(Ordering::SeqCst), 1); - assert_eq!(hook.response_calls.load(Ordering::SeqCst), 1); + anyhow::ensure!(hook.prompt_calls.load(Ordering::SeqCst) == 1); + anyhow::ensure!(hook.response_calls.load(Ordering::SeqCst) == 1); let seen_prompt = hook .seen_prompt @@ -104,12 +104,12 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .map_err(|_| anyhow!("response hook state unavailable"))? 
.clone(); - assert!( + anyhow::ensure!( seen_prompt .as_deref() .is_some_and(|prompt| prompt.contains("Entertain me!")) ); - assert!( + anyhow::ensure!( seen_response .as_deref() .is_some_and(|captured| !captured.is_empty()) diff --git a/rig/rig-core/tests/openrouter/typed_prompt_tools.rs b/rig/rig-core/tests/openrouter/typed_prompt_tools.rs index 2a57d71e9..154712a73 100644 --- a/rig/rig-core/tests/openrouter/typed_prompt_tools.rs +++ b/rig/rig-core/tests/openrouter/typed_prompt_tools.rs @@ -94,7 +94,7 @@ async fn prompt_typed_with_tool_call_roundtrip() -> Result<()> { .prompt_typed("Hello, whats the weather in London?") .await?; - assert!( + anyhow::ensure!( call_count.load(Ordering::SeqCst) >= 1, "expected the weather tool to be executed at least once" ); diff --git a/rig/rig-core/tests/xai/extractor_usage.rs b/rig/rig-core/tests/xai/extractor_usage.rs index c7ef02180..f598105ee 100644 --- a/rig/rig-core/tests/xai/extractor_usage.rs +++ b/rig/rig-core/tests/xai/extractor_usage.rs @@ -23,17 +23,18 @@ struct Address { zip_code: Option, } -fn assert_compatible_professions(left: Option<&str>, right: &str) { +fn assert_compatible_professions(left: Option<&str>, right: &str) -> Result<()> { let left = left - .expect("profession should be present") + .ok_or_else(|| anyhow::anyhow!("profession should be present"))? .trim() .to_ascii_lowercase(); let right = right.trim().to_ascii_lowercase(); - assert!( + anyhow::ensure!( left == right || left.contains(&right) || right.contains(&left), "expected compatible professions, got {left:?} and {right:?}" ); + Ok(()) } #[tokio::test] @@ -46,9 +47,9 @@ async fn extract_backward_compatibility() -> Result<()> { .extract("John Doe is a 30 year old software engineer.") .await?; - assert_eq!(person.name, Some("John Doe".to_string())); - assert_eq!(person.age, Some(30)); - assert_compatible_professions(person.profession.as_deref(), "software engineer"); + anyhow::ensure!(person.name.as_deref() == Some("John Doe")); + anyhow::ensure!(person.age == Some(30)); + assert_compatible_professions(person.profession.as_deref(), "software engineer")?; Ok(()) } @@ -63,12 +64,12 @@ async fn extract_with_usage_returns_data_and_usage() -> Result<()> { .extract_with_usage("Jane Smith is a 45 year old data scientist.") .await?; - assert_eq!(response.data.name, Some("Jane Smith".to_string())); - assert_eq!(response.data.age, Some(45)); - assert_compatible_professions(response.data.profession.as_deref(), "data scientist"); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.output_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.name.as_deref() == Some("Jane Smith")); + anyhow::ensure!(response.data.age == Some(45)); + assert_compatible_professions(response.data.profession.as_deref(), "data scientist")?; + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.output_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -90,12 +91,12 @@ async fn extract_with_chat_history_with_usage_works() -> Result<()> { ) .await?; - assert_eq!(response.data.street, Some("123 Main St".to_string())); - assert_eq!(response.data.city, Some("Springfield".to_string())); - assert_eq!(response.data.state, Some("IL".to_string())); - assert_eq!(response.data.zip_code, Some("62701".to_string())); - assert!(response.usage.input_tokens > 0); - assert!(response.usage.total_tokens > 0); + anyhow::ensure!(response.data.street.as_deref() == Some("123 Main St")); + 
anyhow::ensure!(response.data.city.as_deref() == Some("Springfield")); + anyhow::ensure!(response.data.state.as_deref() == Some("IL")); + anyhow::ensure!(response.data.zip_code.as_deref() == Some("62701")); + anyhow::ensure!(response.usage.input_tokens > 0); + anyhow::ensure!(response.usage.total_tokens > 0); Ok(()) } @@ -110,13 +111,13 @@ async fn extract_and_extract_with_usage_return_same_data() -> Result<()> { let person = extractor.extract(text).await?; let response = extractor.extract_with_usage(text).await?; - assert_eq!(person.name, Some("Bob Johnson".to_string())); - assert_eq!(response.data.name, Some("Bob Johnson".to_string())); - assert_eq!(person.age, Some(55)); - assert_eq!(response.data.age, Some(55)); - assert_compatible_professions(person.profession.as_deref(), "retired teacher"); - assert_compatible_professions(response.data.profession.as_deref(), "retired teacher"); - assert!(response.usage.total_tokens > 0, "usage should be populated"); + anyhow::ensure!(person.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(response.data.name.as_deref() == Some("Bob Johnson")); + anyhow::ensure!(person.age == Some(55)); + anyhow::ensure!(response.data.age == Some(55)); + assert_compatible_professions(person.profession.as_deref(), "retired teacher")?; + assert_compatible_professions(response.data.profession.as_deref(), "retired teacher")?; + anyhow::ensure!(response.usage.total_tokens > 0, "usage should be populated"); Ok(()) } @@ -130,13 +131,13 @@ async fn usage_tracking_works_for_different_schemas() -> Result<()> { let person_response = person_extractor .extract_with_usage("Alice is a 25 year old developer.") .await?; - assert!(person_response.usage.total_tokens > 0); + anyhow::ensure!(person_response.usage.total_tokens > 0); let address_extractor = client.extractor::<Address>
(xai::GROK_3_MINI).build(); let address_response = address_extractor .extract_with_usage("456 Oak Avenue, Cambridge, MA 02139") .await?; - assert!(address_response.usage.total_tokens > 0); + anyhow::ensure!(address_response.usage.total_tokens > 0); Ok(()) } diff --git a/rig/rig-core/tests/xai/multi_extract.rs b/rig/rig-core/tests/xai/multi_extract.rs index a35e4fb53..b0be28f9b 100644 --- a/rig/rig-core/tests/xai/multi_extract.rs +++ b/rig/rig-core/tests/xai/multi_extract.rs @@ -116,7 +116,7 @@ async fn batch_multi_extract_chain() -> Result<()> { ) .await?; - assert_eq!(responses.len(), 3); + anyhow::ensure!(responses.len() == 3); assert_contains_any( &responses[0].names, diff --git a/rig/rig-core/tests/xai/permission_control.rs b/rig/rig-core/tests/xai/permission_control.rs index 626159d59..c9dfb13c2 100644 --- a/rig/rig-core/tests/xai/permission_control.rs +++ b/rig/rig-core/tests/xai/permission_control.rs @@ -178,8 +178,8 @@ async fn permission_control_prompt_example() -> Result<()> { .await?; let last = last_result.lock().expect("lock last_result").clone(); - assert_eq!(last.as_deref(), Some("hello world")); - assert_eq!(call_count.load(Ordering::SeqCst), 2); + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!(call_count.load(Ordering::SeqCst) == 2); Ok(()) } @@ -216,7 +216,7 @@ async fn permission_control_streaming_example() -> Result<()> { let final_response = stream_to_stdout(&mut stream).await?; let last = last_result.lock().expect("lock last_result").clone(); assert_nonempty_response(final_response.response()); - assert!( + anyhow::ensure!( final_response .response() .to_ascii_lowercase() @@ -224,8 +224,8 @@ async fn permission_control_streaming_example() -> Result<()> { "expected the streamed final response to mention the file content, got {:?}", final_response.response() ); - assert_eq!(last.as_deref(), Some("hello world")); - assert_eq!(call_count.load(Ordering::SeqCst), 2); + anyhow::ensure!(last.as_deref() == Some("hello world")); + anyhow::ensure!(call_count.load(Ordering::SeqCst) == 2); Ok(()) } diff --git a/rig/rig-core/tests/xai/request_hook.rs b/rig/rig-core/tests/xai/request_hook.rs index 93e6a71db..bf437bd0a 100644 --- a/rig/rig-core/tests/xai/request_hook.rs +++ b/rig/rig-core/tests/xai/request_hook.rs @@ -88,8 +88,8 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .await?; assert_nonempty_response(&response); - assert_eq!(hook.prompt_calls.load(Ordering::SeqCst), 1); - assert_eq!(hook.response_calls.load(Ordering::SeqCst), 1); + anyhow::ensure!(hook.prompt_calls.load(Ordering::SeqCst) == 1); + anyhow::ensure!(hook.response_calls.load(Ordering::SeqCst) == 1); let seen_prompt = hook .seen_prompt @@ -102,12 +102,12 @@ async fn request_hook_records_prompt_and_response() -> Result<()> { .map_err(|_| anyhow!("response hook state unavailable"))? 
         .clone();
-    assert!(
+    anyhow::ensure!(
         seen_prompt
             .as_deref()
             .is_some_and(|prompt| prompt.contains("Entertain me!"))
     );
-    assert!(
+    anyhow::ensure!(
         seen_response
             .as_deref()
             .is_some_and(|captured| !captured.is_empty())
diff --git a/rig/rig-core/tests/xai/typed_prompt_tools.rs b/rig/rig-core/tests/xai/typed_prompt_tools.rs
index f2f3cc8b5..2e91a1f84 100644
--- a/rig/rig-core/tests/xai/typed_prompt_tools.rs
+++ b/rig/rig-core/tests/xai/typed_prompt_tools.rs
@@ -92,7 +92,7 @@ async fn prompt_typed_with_tool_call_roundtrip() -> Result<()> {
         .prompt_typed("Hello, whats the weather in London?")
         .await?;
-    assert!(
+    anyhow::ensure!(
         call_count.load(Ordering::SeqCst) >= 1,
         "expected the weather tool to be executed at least once"
     );

From a93b55a7daf2551c0ae56e254caf0a9e79ce2153 Mon Sep 17 00:00:00 2001
From: stephen
Date: Thu, 23 Apr 2026 19:33:04 -0700
Subject: [PATCH 5/5] Update lib.rs

---
 rig-integrations/rig-lancedb/src/lib.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rig-integrations/rig-lancedb/src/lib.rs b/rig-integrations/rig-lancedb/src/lib.rs
index b1521a03b..debca7233 100644
--- a/rig-integrations/rig-lancedb/src/lib.rs
+++ b/rig-integrations/rig-lancedb/src/lib.rs
@@ -390,7 +390,7 @@ where
 /// let result = vector_store_index
 ///     .top_n::<WordDefinition>("My boss says I zindle too much, what does that mean?", 1)
 ///     .await?;
-/// ```ignore
+/// ```
     async fn top_n<T: for<'a> Deserialize<'a> + Send>(
         &self,
         req: VectorSearchRequest,