Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/wags_llm/client/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@

import json
import logging
from typing import Any
from typing import Any, Literal

import boto3

from wags_llm.client.base import InvokeJsonResponse, LLMJsonClient
from wags_llm.client.exceptions import (
LLMEmptyResponseError,
LLMInvalidEffortError,
LLMInvocationError,
LLMJsonDecodeError,
LLMResponseFormatError,
)

# Module-level logger named after this module.
_logger = logging.getLogger(__name__)
# Effort strings accepted by BedrockClaudeJsonClient. None is also accepted
# by the constructor, but it is checked separately from this tuple.
_VALID_EFFORT_LEVELS = ("high", "medium", "low")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use an enum here?



class BedrockClaudeJsonClient(LLMJsonClient):
Expand All @@ -31,6 +33,7 @@ def __init__(
profile_name: str,
max_tokens: int = 300,
temperature: float = 0.0,
effort: Literal["high", "medium", "low"] | None = None,
) -> None:
"""Initialize the Bedrock Claude client.

Expand All @@ -39,18 +42,26 @@ def __init__(
:param profile_name: AWS profile name.
:param max_tokens: Maximum number of tokens to request from the model.
:param temperature: Sampling temperature.
:param effort: Optional adaptive thinking effort level for supported Claude models using Bedrock Converse: "high", "medium", "low", or None to use the model default.
:raise LLMInvalidEffortError: If effort is not "high", "medium", "low", or None.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, this was a Claude-specific parameter. I think the custom error should live in this module. If I'm wrong, then we can leave it where it is.

"""
if effort is not None and effort not in _VALID_EFFORT_LEVELS:
msg = f"Invalid effort '{effort}'; must be one of 'high', 'medium', 'low', or None."
raise LLMInvalidEffortError(msg)

_logger.debug(
"BedrockClaudeJsonClient config: model_id='%s', region_name='%s', profile_name='%s', max_tokens=%i, temperature=%f",
"BedrockClaudeJsonClient config: model_id='%s', region_name='%s', profile_name='%s', max_tokens=%i, temperature=%f, effort=%s",
model_id,
region_name,
profile_name,
max_tokens,
temperature,
effort,
)
self.model_id = model_id
self.max_tokens = max_tokens
self.temperature = temperature
self.effort = effort

session = boto3.Session(profile_name=profile_name)
self._client = session.client("bedrock-runtime", region_name=region_name)
Expand Down Expand Up @@ -98,6 +109,12 @@ def invoke_json(
},
}

adaptive_thinking_params: dict[str, Any] = {"thinking": {"type": "adaptive"}}
if self.effort:
adaptive_thinking_params["output_config"] = {"effort": self.effort}

converse_params["additionalModelRequestFields"] = adaptive_thinking_params

Comment on lines +112 to +117

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should additionalModelRequestFields only be added if effort is provided?

if json_schema:
converse_params["outputConfig"] = {
"textFormat": {
Expand Down
4 changes: 4 additions & 0 deletions src/wags_llm/client/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ class LLMEmptyResponseError(LLMClientError):

class LLMJsonDecodeError(LLMClientError):
"""Raised when the model output is not valid JSON."""


class LLMInvalidEffortError(LLMClientError):
    """Raised when the effort parameter is not a valid value.

    The Bedrock Claude client raises this at construction time when ``effort``
    is neither ``None`` nor one of ``"high"``, ``"medium"``, or ``"low"``.
    """
78 changes: 77 additions & 1 deletion tests/integration/client/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from wags_llm.client.bedrock import (
BedrockClaudeJsonClient,
LLMEmptyResponseError,
LLMInvalidEffortError,
LLMInvocationError,
LLMJsonDecodeError,
LLMResponseFormatError,
Expand All @@ -31,14 +32,16 @@ def __init__(self, response=None, error=None):
"""
self.response = response
self.error = error
self.captured_request = None

def converse(self, **kwargs):
    """Record the request and return a fake converse response.

    :param kwargs: Converse request arguments; stored on ``captured_request`` so tests can inspect them.
    :return: Fake response payload.
    :raise Exception: If configured with an error.
    """
    # Capture before the error check so the request stays inspectable even
    # when the fake is configured to raise.
    self.captured_request = kwargs
    if self.error is not None:
        raise self.error
    return self.response
Expand Down Expand Up @@ -66,6 +69,79 @@ def client(self, service_name: str, region_name: str):
return self.runtime_client


def test_invoke_json_with_effort():
    """Verify the Converse request carries adaptive thinking plus the configured effort."""
    converse_response = {"output": {"message": {"content": [{"text": '{"value": 1}'}]}}}
    runtime_stub = FakeBedrockRuntimeClient(response=converse_response)
    expected_fields = {
        "thinking": {"type": "adaptive"},
        "output_config": {"effort": "medium"},
    }

    session_patch = patch(
        "wags_llm.client.bedrock.boto3.Session",
        return_value=FakeSession(runtime_stub),
    )
    with session_patch:
        client = BedrockClaudeJsonClient(
            model_id=TEST_MODEL_ID,
            region_name=TEST_REGION_NAME,
            profile_name=TEST_PROFILE_NAME,
            effort="medium",
        )
        client.invoke_json(
            system_prompt=TEST_SYSTEM_PROMPT,
            user_prompt=TEST_USER_PROMPT,
        )

    # The fake runtime client stores the kwargs of the last converse() call.
    sent_request = runtime_stub.captured_request
    assert sent_request["additionalModelRequestFields"] == expected_fields


def test_invoke_json_without_effort_omits_field():
    """Test that invoke_json omits the effort ``output_config`` when effort is unset.

    ``additionalModelRequestFields`` itself is still sent: it must contain only
    the adaptive thinking configuration, with no ``output_config`` entry.
    """
    fake_runtime_client = FakeBedrockRuntimeClient(
        response={"output": {"message": {"content": [{"text": '{"value": 1}'}]}}}
    )

    with patch(
        "wags_llm.client.bedrock.boto3.Session",
        return_value=FakeSession(fake_runtime_client),
    ):
        # No ``effort`` argument: the client falls back to its default of None.
        client = BedrockClaudeJsonClient(
            model_id=TEST_MODEL_ID,
            region_name=TEST_REGION_NAME,
            profile_name=TEST_PROFILE_NAME,
        )

        client.invoke_json(
            system_prompt=TEST_SYSTEM_PROMPT,
            user_prompt=TEST_USER_PROMPT,
        )

    # Exact equality also proves that no "output_config" key was added.
    assert fake_runtime_client.captured_request["additionalModelRequestFields"] == {
        "thinking": {"type": "adaptive"},
    }


def test_invalid_effort_raises():
    """Verify that constructing the client with an unsupported effort raises LLMInvalidEffortError."""
    constructor_kwargs = {
        "model_id": TEST_MODEL_ID,
        "region_name": TEST_REGION_NAME,
        "profile_name": TEST_PROFILE_NAME,
        "effort": "extreme",
    }

    expect_invalid_effort = pytest.raises(LLMInvalidEffortError, match=r"Invalid effort")
    with expect_invalid_effort:
        BedrockClaudeJsonClient(**constructor_kwargs)


def test_invoke_json_success():
"""Test that invoke_json works correctly"""
fake_runtime_client = FakeBedrockRuntimeClient(
Expand Down
Loading