Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from data_designer.engine.models.clients.throttled import ThrottledModelClient
from data_designer.engine.models.clients.types import (
AssistantMessage,
ChatCompletionChoice,
ChatCompletionRequest,
ChatCompletionResponse,
EmbeddingRequest,
Expand All @@ -31,6 +32,7 @@

__all__ = [
"AssistantMessage",
"ChatCompletionChoice",
"ChatCompletionRequest",
"ChatCompletionResponse",
"EmbeddingRequest",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class AnthropicClient(HttpModelClient):
"stop",
"max_tokens",
"tools",
"n",
"response_format",
"frequency_penalty",
"presence_penalty",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from data_designer.engine.models.clients.types import (
AssistantMessage,
ChatCompletionChoice,
ChatCompletionResponse,
ImagePayload,
ToolCall,
Expand All @@ -36,35 +37,75 @@


def parse_chat_completion_response(response: Any) -> ChatCompletionResponse:
    """Parse a raw chat-completion payload into a ``ChatCompletionResponse``.

    Every entry found under ``choices`` is parsed, not just the first one, so
    multi-choice responses keep all of their candidates.

    Args:
        response: Raw provider response; fields are read through
            ``get_value_from``.

    Returns:
        A ``ChatCompletionResponse`` whose ``message`` is the first parsed
        choice's message (or an empty ``AssistantMessage()`` when no choices
        were present) and whose ``choices`` holds every parsed choice.
    """
    raw_choices = normalize_choice_list(get_value_from(response, "choices"))
    choices = [parse_chat_completion_choice(choice) for choice in raw_choices]

    # The primary message stays the first choice for single-result callers.
    assistant_message = choices[0].message if choices else AssistantMessage()

    # Images are counted across all choices; zero is reported as None.
    generated_images = sum(len(choice.message.images) for choice in choices)
    usage = extract_usage(
        get_value_from(response, "usage"),
        generated_images=generated_images if generated_images else None,
    )
    # Only the primary message's reasoning content feeds the token back-fill.
    usage = fill_reasoning_token_count_from_content(usage, assistant_message.reasoning_content)
    return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response, choices=choices)


async def aparse_chat_completion_response(response: Any) -> ChatCompletionResponse:
    """Asynchronously parse a raw chat-completion payload.

    Async counterpart of the synchronous parser: each choice is parsed with
    ``aparse_chat_completion_choice`` and awaited one after another, in the
    order the choices appear.

    Args:
        response: Raw provider response; fields are read through
            ``get_value_from``.

    Returns:
        A ``ChatCompletionResponse`` whose ``message`` is the first parsed
        choice's message (or an empty ``AssistantMessage()`` when no choices
        were present) and whose ``choices`` holds every parsed choice.
    """
    raw_choices = normalize_choice_list(get_value_from(response, "choices"))
    choices = [await aparse_chat_completion_choice(choice) for choice in raw_choices]

    # The primary message stays the first choice for single-result callers.
    assistant_message = choices[0].message if choices else AssistantMessage()

    # Images are counted across all choices; zero is reported as None.
    generated_images = sum(len(choice.message.images) for choice in choices)
    usage = extract_usage(
        get_value_from(response, "usage"),
        generated_images=generated_images if generated_images else None,
    )
    # Only the primary message's reasoning content feeds the token back-fill.
    usage = fill_reasoning_token_count_from_content(usage, assistant_message.reasoning_content)
    return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response, choices=choices)


def normalize_choice_list(raw_choices: Any) -> list[Any]:
    """Coerce the raw ``choices`` value into a list of choice objects.

    Args:
        raw_choices: Whatever was found under ``choices``: ``None``, a list,
            a tuple, or a single choice object.

    Returns:
        ``[]`` for ``None``; the same list object when a list is given; a new
        list for a tuple; otherwise a one-element list wrapping the value.
    """
    if raw_choices is None:
        return []
    if isinstance(raw_choices, list):
        return raw_choices
    if isinstance(raw_choices, tuple):
        # A tuple is a sequence of choices, not a single choice to wrap.
        return list(raw_choices)
    return [raw_choices]


def parse_chat_completion_choice(choice: Any) -> ChatCompletionChoice:
    """Build a ``ChatCompletionChoice`` from one raw choice entry.

    Args:
        choice: A single raw entry from the response's ``choices``.

    Returns:
        The parsed choice, carrying its assistant message (with images
        extracted synchronously), its index and its finish reason.
    """
    raw_message = get_value_from(choice, "message")
    extracted_images = extract_images_from_chat_message(raw_message)
    parsed_message = parse_assistant_message(raw_message, images=extracted_images)
    choice_index = parse_choice_index(get_value_from(choice, "index"))
    choice_finish_reason = parse_choice_finish_reason(get_value_from(choice, "finish_reason"))
    return ChatCompletionChoice(
        message=parsed_message,
        index=choice_index,
        finish_reason=choice_finish_reason,
    )


async def aparse_chat_completion_choice(choice: Any) -> ChatCompletionChoice:
    """Asynchronously build a ``ChatCompletionChoice`` from one raw choice entry.

    Args:
        choice: A single raw entry from the response's ``choices``.

    Returns:
        The parsed choice, carrying its assistant message (with images
        extracted via the awaited async extractor), its index and its finish
        reason.
    """
    raw_message = get_value_from(choice, "message")
    extracted_images = await aextract_images_from_chat_message(raw_message)
    parsed_message = parse_assistant_message(raw_message, images=extracted_images)
    choice_index = parse_choice_index(get_value_from(choice, "index"))
    choice_finish_reason = parse_choice_finish_reason(get_value_from(choice, "finish_reason"))
    return ChatCompletionChoice(
        message=parsed_message,
        index=choice_index,
        finish_reason=choice_finish_reason,
    )


def parse_assistant_message(message: Any, *, images: list[ImagePayload]) -> AssistantMessage:
    """Convert a raw chat message into an ``AssistantMessage``.

    Args:
        message: Raw message object taken from a choice; fields are read
            through ``get_value_from``.
        images: Images already extracted for this message. They are passed in
            so the sync and async callers can each use their own extractor.

    Returns:
        An ``AssistantMessage`` with coerced content, reasoning content,
        tool calls extracted from the message's ``tool_calls`` field, and the
        supplied images.
    """
    return AssistantMessage(
        content=coerce_message_content(get_value_from(message, "content")),
        reasoning_content=extract_reasoning_content(message),
        tool_calls=extract_tool_calls(get_value_from(message, "tool_calls")),
        images=images,
    )


def parse_choice_index(index: Any) -> int | None:
return index if isinstance(index, int) else None


def parse_choice_finish_reason(finish_reason: Any) -> str | None:
return finish_reason if isinstance(finish_reason, str) else None


# ---------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class ChatCompletionRequest:
model: str
messages: list[dict[str, Any]]
tools: list[dict[str, Any]] | None = None
n: int | None = None
temperature: float | None = None
top_p: float | None = None
max_tokens: int | None = None
Expand All @@ -74,11 +75,27 @@ class ChatCompletionRequest:
extra_headers: dict[str, str] | None = None


@dataclass
class ChatCompletionChoice:
    """One candidate completion taken from a chat-completion response."""

    # Parsed assistant message for this choice.
    message: AssistantMessage
    # Position of the choice in the response, when one was provided.
    index: int | None = None
    # Finish reason for this choice, when one was provided.
    finish_reason: str | None = None


@dataclass
class ChatCompletionResponse:
    """Parsed chat-completion result with a primary message and all choices."""

    # Primary assistant message.
    message: AssistantMessage
    # Usage information, when available.
    usage: Usage | None = None
    # Raw response object this result was built from, when kept.
    raw: Any | None = None
    # Every parsed choice; filled from ``message`` when left empty.
    choices: list[ChatCompletionChoice] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Guarantee ``choices`` is non-empty by wrapping ``message`` if needed."""
        if self.choices:
            return
        self.choices = [ChatCompletionChoice(message=self.message)]

    @property
    def messages(self) -> list[AssistantMessage]:
        """Return the assistant message of every choice, in choice order."""
        return [entry.message for entry in self.choices]


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def _build_generation_validation_error(summary: str, exc: ParserException) -> Ge
{
"temperature",
"top_p",
"n",
"max_tokens",
"stop",
"seed",
Expand Down Expand Up @@ -198,7 +199,12 @@ def consolidate_kwargs(self, **kwargs: Any) -> dict[str, Any]:
# --- completion / acompletion ---

def completion(
self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any
self,
messages: list[ChatMessage],
skip_usage_tracking: bool = False,
*,
allow_multiple_choices: bool = True,
**kwargs: Any,
) -> ChatCompletionResponse:
message_payloads = [message.to_dict() for message in messages]
logger.debug(
Expand All @@ -207,6 +213,8 @@ def completion(
)
response = None
kwargs = self.consolidate_kwargs(**kwargs)
if not allow_multiple_choices:
kwargs = self._drop_multi_choice_request_fields(kwargs)
try:
request = self._build_chat_completion_request(message_payloads, kwargs)
response = self._client.completion(request)
Expand All @@ -228,7 +236,12 @@ def completion(
)

async def acompletion(
self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any
self,
messages: list[ChatMessage],
skip_usage_tracking: bool = False,
*,
allow_multiple_choices: bool = True,
**kwargs: Any,
) -> ChatCompletionResponse:
message_payloads = [message.to_dict() for message in messages]
logger.debug(
Expand All @@ -237,6 +250,8 @@ async def acompletion(
)
response = None
kwargs = self.consolidate_kwargs(**kwargs)
if not allow_multiple_choices:
kwargs = self._drop_multi_choice_request_fields(kwargs)
try:
request = self._build_chat_completion_request(message_payloads, kwargs)
response = await self._client.acompletion(request)
Expand Down Expand Up @@ -346,12 +361,14 @@ def generate(

while True:
completion_kwargs = dict(kwargs)
completion_kwargs.pop("allow_multiple_choices", None)
if tool_schemas is not None:
completion_kwargs["tools"] = tool_schemas

completion_response = self.completion(
messages,
skip_usage_tracking=skip_usage_tracking,
allow_multiple_choices=False,
**completion_kwargs,
)

Expand Down Expand Up @@ -451,12 +468,14 @@ async def agenerate(

while True:
completion_kwargs = dict(kwargs)
completion_kwargs.pop("allow_multiple_choices", None)
if tool_schemas is not None:
completion_kwargs["tools"] = tool_schemas

completion_response = await self.acompletion(
messages,
skip_usage_tracking=skip_usage_tracking,
allow_multiple_choices=False,
**completion_kwargs,
)

Expand Down Expand Up @@ -719,6 +738,23 @@ async def aclose(self) -> None:

# --- private helpers ---

@staticmethod
def _drop_multi_choice_request_fields(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``kwargs`` with the multi-choice control ``n`` removed.

    ``n`` is stripped both from the top level and from a dict-valued
    ``extra_body``. If removing it leaves ``extra_body`` empty, the
    ``extra_body`` key is dropped as well. Neither the input mapping nor its
    ``extra_body`` is mutated.
    """
    cleaned = {key: value for key, value in kwargs.items() if key != "n"}

    nested = cleaned.get("extra_body")
    if not isinstance(nested, dict) or "n" not in nested:
        # No nested "n" to strip; leave extra_body exactly as it was given.
        return cleaned

    remaining = {key: value for key, value in nested.items() if key != "n"}
    if remaining:
        cleaned["extra_body"] = remaining
    else:
        del cleaned["extra_body"]
    return cleaned

def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None:
if tool_alias is None:
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,11 +401,12 @@ def test_completion_excludes_openai_specific_params() -> None:
frequency_penalty=0.5,
presence_penalty=0.5,
seed=42,
n=4,
)
client.completion(request)

payload = sync_mock.post.call_args.kwargs["json"]
for field in ("response_format", "frequency_penalty", "presence_penalty", "seed"):
for field in ("response_format", "frequency_penalty", "presence_penalty", "seed", "n"):
assert field not in payload, f"{field!r} should be excluded from Anthropic payload"


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def test_completion_posts_to_chat_completions_route() -> None:
request = ChatCompletionRequest(
model=MODEL,
messages=[{"role": "user", "content": "Hi"}],
n=4,
temperature=0.7,
extra_body={"seed": 42},
extra_headers={"X-Trace": "1"},
Expand All @@ -125,6 +126,7 @@ def test_completion_posts_to_chat_completions_route() -> None:
assert "/chat/completions" in call_args.args[0]
payload = call_args.kwargs["json"]
assert payload["model"] == MODEL
assert payload["n"] == 4
assert payload["temperature"] == 0.7
assert payload["seed"] == 42
assert "timeout" not in payload
Expand Down
Loading
Loading