diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py index df9afc48f..0e7d7907a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py @@ -17,6 +17,7 @@ from data_designer.engine.models.clients.throttled import ThrottledModelClient from data_designer.engine.models.clients.types import ( AssistantMessage, + ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, EmbeddingRequest, @@ -31,6 +32,7 @@ __all__ = [ "AssistantMessage", + "ChatCompletionChoice", "ChatCompletionRequest", "ChatCompletionResponse", "EmbeddingRequest", diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py index 2424d3f8c..17abd8a88 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py @@ -44,6 +44,7 @@ class AnthropicClient(HttpModelClient): "stop", "max_tokens", "tools", + "n", "response_format", "frequency_penalty", "presence_penalty", diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py index 9a2e1cabf..02a3223a2 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py @@ -19,6 +19,7 @@ ) from data_designer.engine.models.clients.types import ( AssistantMessage, + ChatCompletionChoice, ChatCompletionResponse, ImagePayload, ToolCall, @@ -36,35 +37,75 @@ def parse_chat_completion_response(response: Any) -> ChatCompletionResponse: - first_choice = get_first_value_or_none(get_value_from(response, "choices")) - message = get_value_from(first_choice, "message") - tool_calls = extract_tool_calls(get_value_from(message, "tool_calls")) - images = extract_images_from_chat_message(message) - assistant_message = AssistantMessage( - content=coerce_message_content(get_value_from(message, "content")), - reasoning_content=extract_reasoning_content(message), - tool_calls=tool_calls, - images=images, + choices = [ + parse_chat_completion_choice(choice) for choice in normalize_choice_list(get_value_from(response, "choices")) + ] + assistant_message = choices[0].message if choices else AssistantMessage() + generated_images = sum(len(choice.message.images) for choice in choices) + usage = extract_usage( + get_value_from(response, "usage"), + generated_images=generated_images if generated_images else None, ) - usage = extract_usage(get_value_from(response, "usage"), generated_images=len(images) if images else None) usage = fill_reasoning_token_count_from_content(usage, assistant_message.reasoning_content) - return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response) + return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response, choices=choices) async def aparse_chat_completion_response(response: Any) -> ChatCompletionResponse: - first_choice = get_first_value_or_none(get_value_from(response, "choices")) - message = get_value_from(first_choice, "message") - tool_calls = extract_tool_calls(get_value_from(message, "tool_calls")) - images = await aextract_images_from_chat_message(message) - assistant_message = AssistantMessage( + choices = [ + await aparse_chat_completion_choice(choice) + for choice in normalize_choice_list(get_value_from(response, "choices")) + ] + assistant_message = choices[0].message if choices else AssistantMessage() + generated_images = sum(len(choice.message.images) for choice in choices) + usage = extract_usage( + get_value_from(response, "usage"), + generated_images=generated_images if generated_images else None, + ) + usage = fill_reasoning_token_count_from_content(usage, assistant_message.reasoning_content) + return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response, choices=choices) + + +def normalize_choice_list(raw_choices: Any) -> list[Any]: + if raw_choices is None: + return [] + if isinstance(raw_choices, list): + return raw_choices + return [raw_choices] + + +def parse_chat_completion_choice(choice: Any) -> ChatCompletionChoice: + message = get_value_from(choice, "message") + return ChatCompletionChoice( + message=parse_assistant_message(message, images=extract_images_from_chat_message(message)), + index=parse_choice_index(get_value_from(choice, "index")), + finish_reason=parse_choice_finish_reason(get_value_from(choice, "finish_reason")), + ) + + +async def aparse_chat_completion_choice(choice: Any) -> ChatCompletionChoice: + message = get_value_from(choice, "message") + return ChatCompletionChoice( + message=parse_assistant_message(message, images=await aextract_images_from_chat_message(message)), + index=parse_choice_index(get_value_from(choice, "index")), + finish_reason=parse_choice_finish_reason(get_value_from(choice, "finish_reason")), + ) + + +def parse_assistant_message(message: Any, *, images: list[ImagePayload]) -> AssistantMessage: + return AssistantMessage( content=coerce_message_content(get_value_from(message, "content")), reasoning_content=extract_reasoning_content(message), - tool_calls=tool_calls, + tool_calls=extract_tool_calls(get_value_from(message, "tool_calls")), images=images, ) - usage = extract_usage(get_value_from(response, "usage"), generated_images=len(images) if images else None) - usage = fill_reasoning_token_count_from_content(usage, assistant_message.reasoning_content) - return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response) + + +def parse_choice_index(index: Any) -> int | None: + return index if isinstance(index, int) else None + + +def parse_choice_finish_reason(finish_reason: Any) -> str | None: + return finish_reason if isinstance(finish_reason, str) else None # --------------------------------------------------------------------------- diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py index 58db1f944..d28f5f16b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py @@ -61,6 +61,7 @@ class ChatCompletionRequest: model: str messages: list[dict[str, Any]] tools: list[dict[str, Any]] | None = None + n: int | None = None temperature: float | None = None top_p: float | None = None max_tokens: int | None = None @@ -74,11 +75,27 @@ class ChatCompletionRequest: extra_headers: dict[str, str] | None = None +@dataclass +class ChatCompletionChoice: + message: AssistantMessage + index: int | None = None + finish_reason: str | None = None + + @dataclass class ChatCompletionResponse: message: AssistantMessage usage: Usage | None = None raw: Any | None = None + choices: list[ChatCompletionChoice] = field(default_factory=list) + + def __post_init__(self) -> None: + if not self.choices: + self.choices = [ChatCompletionChoice(message=self.message)] + + @property + def messages(self) -> list[AssistantMessage]: + return [choice.message for choice in self.choices] @dataclass diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 073f62ffa..81a935282 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -85,6 +85,7 @@ def _build_generation_validation_error(summary: str, exc: ParserException) -> Ge { "temperature", "top_p", + "n", "max_tokens", "stop", "seed", @@ -198,7 +199,12 @@ def consolidate_kwargs(self, **kwargs: Any) -> dict[str, Any]: # --- completion / acompletion --- def completion( - self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any + self, + messages: list[ChatMessage], + skip_usage_tracking: bool = False, + *, + allow_multiple_choices: bool = True, + **kwargs: Any, ) -> ChatCompletionResponse: message_payloads = [message.to_dict() for message in messages] logger.debug( @@ -207,6 +213,8 @@ def completion( ) response = None kwargs = self.consolidate_kwargs(**kwargs) + if not allow_multiple_choices: + kwargs = self._drop_multi_choice_request_fields(kwargs) try: request = self._build_chat_completion_request(message_payloads, kwargs) response = self._client.completion(request) @@ -228,7 +236,12 @@ def completion( ) async def acompletion( - self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any + self, + messages: list[ChatMessage], + skip_usage_tracking: bool = False, + *, + allow_multiple_choices: bool = True, + **kwargs: Any, ) -> ChatCompletionResponse: message_payloads = [message.to_dict() for message in messages] logger.debug( @@ -237,6 +250,8 @@ async def acompletion( ) response = None kwargs = self.consolidate_kwargs(**kwargs) + if not allow_multiple_choices: + kwargs = self._drop_multi_choice_request_fields(kwargs) try: request = self._build_chat_completion_request(message_payloads, kwargs) response = await self._client.acompletion(request) @@ -346,12 +361,14 @@ def generate( while True: completion_kwargs = dict(kwargs) + completion_kwargs.pop("allow_multiple_choices", None) if tool_schemas is not None: completion_kwargs["tools"] = tool_schemas completion_response = self.completion( messages, skip_usage_tracking=skip_usage_tracking, + allow_multiple_choices=False, **completion_kwargs, ) @@ -451,12 +468,14 @@ async def agenerate( while True: completion_kwargs = dict(kwargs) + completion_kwargs.pop("allow_multiple_choices", None) if tool_schemas is not None: completion_kwargs["tools"] = tool_schemas completion_response = await self.acompletion( messages, skip_usage_tracking=skip_usage_tracking, + allow_multiple_choices=False, **completion_kwargs, ) @@ -719,6 +738,23 @@ async def aclose(self) -> None: # --- private helpers --- + @staticmethod + def _drop_multi_choice_request_fields(kwargs: dict[str, Any]) -> dict[str, Any]: + """Remove request controls that would make a single-result API discard choices.""" + sanitized = dict(kwargs) + sanitized.pop("n", None) + + extra_body = sanitized.get("extra_body") + if isinstance(extra_body, dict) and "n" in extra_body: + extra_body = dict(extra_body) + extra_body.pop("n", None) + if extra_body: + sanitized["extra_body"] = extra_body + else: + sanitized.pop("extra_body", None) + + return sanitized + def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None: if tool_alias is None: return None diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py index 66f6d85f7..1b1022d15 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py @@ -401,11 +401,12 @@ def test_completion_excludes_openai_specific_params() -> None: frequency_penalty=0.5, presence_penalty=0.5, seed=42, + n=4, ) client.completion(request) payload = sync_mock.post.call_args.kwargs["json"] - for field in ("response_format", "frequency_penalty", "presence_penalty", "seed"): + for field in ("response_format", "frequency_penalty", "presence_penalty", "seed", "n"): assert field not in payload, f"{field!r} should be excluded from Anthropic payload" diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py index ed2faf9d4..185117978 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py @@ -115,6 +115,7 @@ def test_completion_posts_to_chat_completions_route() -> None: request = ChatCompletionRequest( model=MODEL, messages=[{"role": "user", "content": "Hi"}], + n=4, temperature=0.7, extra_body={"seed": 42}, extra_headers={"X-Trace": "1"}, @@ -125,6 +126,7 @@ def test_completion_posts_to_chat_completions_route() -> None: assert "/chat/completions" in call_args.args[0] payload = call_args.kwargs["json"] assert payload["model"] == MODEL + assert payload["n"] == 4 assert payload["temperature"] == 0.7 assert payload["seed"] == 42 assert "timeout" not in payload diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py index c9c13de9a..bc8ce6a46 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py @@ -6,6 +6,7 @@ import pytest from data_designer.engine.models.clients.parsing import ( + aparse_chat_completion_response, extract_reasoning_content, extract_tool_calls, extract_usage, @@ -13,7 +14,9 @@ parse_chat_completion_response, ) from data_designer.engine.models.clients.types import ( + AssistantMessage, ChatCompletionRequest, + ChatCompletionResponse, EmbeddingRequest, ImageGenerationRequest, TransportKwargs, @@ -21,6 +24,69 @@ ) from data_designer.engine.models.usage import TokenCountSource +# --- ChatCompletionResponse compatibility --- + + +def test_chat_completion_response_exposes_choices_for_single_message() -> None: + message = AssistantMessage(content="ok") + response = ChatCompletionResponse(message=message) + + assert response.message is message + assert response.choices[0].message is message + assert response.messages == [message] + + +def test_parse_chat_completion_response_preserves_all_choices() -> None: + response = parse_chat_completion_response( + { + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "first"}, + "finish_reason": "stop", + }, + { + "index": 1, + "message": {"role": "assistant", "content": "second"}, + "finish_reason": "length", + }, + ], + "usage": {"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7}, + } + ) + + assert response.message.content == "first" + assert [choice.message.content for choice in response.choices] == ["first", "second"] + assert [choice.index for choice in response.choices] == [0, 1] + assert [choice.finish_reason for choice in response.choices] == ["stop", "length"] + assert [message.content for message in response.messages] == ["first", "second"] + + +@pytest.mark.asyncio +async def test_aparse_chat_completion_response_preserves_all_choices() -> None: + response = await aparse_chat_completion_response( + { + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "first"}, + "finish_reason": "stop", + }, + { + "index": 1, + "message": {"role": "assistant", "content": "second"}, + "finish_reason": "length", + }, + ], + } + ) + + assert response.message.content == "first" + assert [choice.message.content for choice in response.choices] == ["first", "second"] + assert [choice.index for choice in response.choices] == [0, 1] + assert [choice.finish_reason for choice in response.choices] == ["stop", "length"] + + # --- TransportKwargs.from_request: extra_body flattening (default) --- @@ -39,6 +105,13 @@ def test_extra_body_keys_are_flattened_into_body() -> None: assert "extra_body" not in transport.body +def test_chat_completion_request_n_is_forwarded_into_body() -> None: + request = ChatCompletionRequest(model="m", messages=[], n=4) + transport = TransportKwargs.from_request(request) + + assert transport.body["n"] == 4 + + def test_extra_body_none_produces_no_extra_keys() -> None: request = ChatCompletionRequest(model="m", messages=[], temperature=0.5) transport = TransportKwargs.from_request(request) diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 89abd74f5..4587f0722 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -11,6 +11,7 @@ from data_designer.engine.mcp.errors import MCPConfigurationError, MCPToolError from data_designer.engine.models.clients.types import ( AssistantMessage, + ChatCompletionRequest, ChatCompletionResponse, EmbeddingResponse, ImageGenerationResponse, @@ -31,6 +32,15 @@ def _make_response(content: str | None = None, **kwargs: Any) -> ChatCompletionR return make_stub_completion_response(content=content, **kwargs) +def _assert_no_multi_choice_request( + request: Any, + expected_extra_body: dict[str, Any] | None = None, +) -> None: + assert isinstance(request, ChatCompletionRequest) + assert request.n is None + assert request.extra_body == expected_extra_body + + @pytest.fixture def stub_model_facade( stub_model_configs: list[Any], @@ -122,6 +132,78 @@ def capture_and_return(*args: Any, **kwargs: Any) -> ChatCompletionResponse: assert captured_messages[0] == expected_messages +def test_generate_drops_n_from_single_result_request( + stub_model_facade: ModelFacade, + stub_model_client: MagicMock, +) -> None: + stub_model_client.completion.return_value = _make_response("Hello!") + + stub_model_facade.generate(prompt="does not matter", parser=lambda x: x, n=4) + + _assert_no_multi_choice_request(stub_model_client.completion.call_args.args[0]) + + +def test_generate_drops_extra_body_n_from_single_result_request( + stub_model_facade: ModelFacade, + stub_model_client: MagicMock, +) -> None: + stub_model_client.completion.return_value = _make_response("Hello!") + + stub_model_facade.generate(prompt="does not matter", parser=lambda x: x, extra_body={"n": 4, "seed": 42}) + + _assert_no_multi_choice_request( + stub_model_client.completion.call_args.args[0], + expected_extra_body={"seed": 42}, + ) + + +def test_generate_drops_configured_extra_body_n_from_single_result_request( + stub_model_configs: list[Any], + stub_model_facade: ModelFacade, + stub_model_client: MagicMock, +) -> None: + stub_model_configs[0].inference_parameters.extra_body = {"n": 4, "seed": 42} + stub_model_facade.model_provider.extra_body = {"n": 5, "provider": "kept"} + stub_model_client.completion.return_value = _make_response("Hello!") + + stub_model_facade.generate(prompt="does not matter", parser=lambda x: x) + + _assert_no_multi_choice_request( + stub_model_client.completion.call_args.args[0], + expected_extra_body={"seed": 42, "provider": "kept"}, + ) + + +@pytest.mark.asyncio +async def test_agenerate_drops_n_from_single_result_request( + stub_model_facade: ModelFacade, + stub_model_client: MagicMock, +) -> None: + stub_model_client.acompletion = AsyncMock(return_value=_make_response("Hello!")) + + await stub_model_facade.agenerate(prompt="does not matter", parser=lambda x: x, n=4) + + _assert_no_multi_choice_request(stub_model_client.acompletion.call_args.args[0]) + + +@pytest.mark.asyncio +async def test_agenerate_drops_configured_extra_body_n_from_single_result_request( + stub_model_configs: list[Any], + stub_model_facade: ModelFacade, + stub_model_client: MagicMock, +) -> None: + stub_model_configs[0].inference_parameters.extra_body = {"n": 4, "seed": 42} + stub_model_facade.model_provider.extra_body = {"n": 5, "provider": "kept"} + stub_model_client.acompletion = AsyncMock(return_value=_make_response("Hello!")) + + await stub_model_facade.agenerate(prompt="does not matter", parser=lambda x: x) + + _assert_no_multi_choice_request( + stub_model_client.acompletion.call_args.args[0], + expected_extra_body={"seed": 42, "provider": "kept"}, + ) + + @patch.object(ModelFacade, "completion", autospec=True) def test_generate_includes_parser_validation_detail_in_user_facing_error( mock_completion: Any, @@ -421,6 +503,21 @@ def test_completion_with_kwargs( assert stub_model_client.completion.call_count == 1 +def test_completion_forwards_n_to_request( + stub_completion_messages: list[ChatMessage], + stub_model_facade: ModelFacade, + stub_model_client: MagicMock, +) -> None: + expected_response = _make_response("Test response") + stub_model_client.completion.return_value = expected_response + + stub_model_facade.completion(stub_completion_messages, n=4) + + request = stub_model_client.completion.call_args.args[0] + assert isinstance(request, ChatCompletionRequest) + assert request.n == 4 + + def test_generate_text_embeddings_success( stub_model_facade: ModelFacade, stub_model_client: MagicMock,