From dd1923bd8ff2ef7848cd070517504d838eb13dec Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 18 May 2026 10:24:11 -0600 Subject: [PATCH 1/6] fix chat completion multi-choice support Restore the chat completion n request field and preserve all returned choices in the canonical response while keeping response.message as the first choice. Add coverage for request forwarding, compatibility access, multi-choice parsing, and generate forwarding. Fixes #620 Signed-off-by: Nabin Mulepati --- .../engine/models/clients/__init__.py | 2 + .../models/clients/adapters/anthropic.py | 1 + .../engine/models/clients/parsing.py | 81 ++++++++++++++----- .../engine/models/clients/types.py | 17 ++++ .../src/data_designer/engine/models/facade.py | 1 + .../models/clients/test_openai_compatible.py | 2 + .../engine/models/clients/test_parsing.py | 47 +++++++++++ .../tests/engine/models/test_facade.py | 28 +++++++ 8 files changed, 159 insertions(+), 20 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py index df9afc48f..0e7d7907a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py @@ -17,6 +17,7 @@ from data_designer.engine.models.clients.throttled import ThrottledModelClient from data_designer.engine.models.clients.types import ( AssistantMessage, + ChatCompletionChoice, ChatCompletionRequest, ChatCompletionResponse, EmbeddingRequest, @@ -31,6 +32,7 @@ __all__ = [ "AssistantMessage", + "ChatCompletionChoice", "ChatCompletionRequest", "ChatCompletionResponse", "EmbeddingRequest", diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py index 2424d3f8c..17abd8a88 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py @@ -44,6 +44,7 @@ class AnthropicClient(HttpModelClient): "stop", "max_tokens", "tools", + "n", "response_format", "frequency_penalty", "presence_penalty", diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py index 9a2e1cabf..02a3223a2 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py @@ -19,6 +19,7 @@ ) from data_designer.engine.models.clients.types import ( AssistantMessage, + ChatCompletionChoice, ChatCompletionResponse, ImagePayload, ToolCall, @@ -36,35 +37,75 @@ def parse_chat_completion_response(response: Any) -> ChatCompletionResponse: - first_choice = get_first_value_or_none(get_value_from(response, "choices")) - message = get_value_from(first_choice, "message") - tool_calls = extract_tool_calls(get_value_from(message, "tool_calls")) - images = extract_images_from_chat_message(message) - assistant_message = AssistantMessage( - content=coerce_message_content(get_value_from(message, "content")), - reasoning_content=extract_reasoning_content(message), - tool_calls=tool_calls, - images=images, + choices = [ + parse_chat_completion_choice(choice) for choice in normalize_choice_list(get_value_from(response, "choices")) + ] + assistant_message = choices[0].message if choices else AssistantMessage() + generated_images = sum(len(choice.message.images) for choice in choices) + usage = extract_usage( + get_value_from(response, "usage"), + generated_images=generated_images if generated_images else None, ) - usage = extract_usage(get_value_from(response, "usage"), generated_images=len(images) if images else None) usage = fill_reasoning_token_count_from_content(usage, assistant_message.reasoning_content) - return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response) + return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response, choices=choices) async def aparse_chat_completion_response(response: Any) -> ChatCompletionResponse: - first_choice = get_first_value_or_none(get_value_from(response, "choices")) - message = get_value_from(first_choice, "message") - tool_calls = extract_tool_calls(get_value_from(message, "tool_calls")) - images = await aextract_images_from_chat_message(message) - assistant_message = AssistantMessage( + choices = [ + await aparse_chat_completion_choice(choice) + for choice in normalize_choice_list(get_value_from(response, "choices")) + ] + assistant_message = choices[0].message if choices else AssistantMessage() + generated_images = sum(len(choice.message.images) for choice in choices) + usage = extract_usage( + get_value_from(response, "usage"), + generated_images=generated_images if generated_images else None, + ) + usage = fill_reasoning_token_count_from_content(usage, assistant_message.reasoning_content) + return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response, choices=choices) + + +def normalize_choice_list(raw_choices: Any) -> list[Any]: + if raw_choices is None: + return [] + if isinstance(raw_choices, list): + return raw_choices + return [raw_choices] + + +def parse_chat_completion_choice(choice: Any) -> ChatCompletionChoice: + message = get_value_from(choice, "message") + return ChatCompletionChoice( + message=parse_assistant_message(message, images=extract_images_from_chat_message(message)), + index=parse_choice_index(get_value_from(choice, "index")), + finish_reason=parse_choice_finish_reason(get_value_from(choice, "finish_reason")), + ) + + +async def aparse_chat_completion_choice(choice: Any) -> ChatCompletionChoice: + message = get_value_from(choice, "message") + return ChatCompletionChoice( + message=parse_assistant_message(message, images=await aextract_images_from_chat_message(message)), + index=parse_choice_index(get_value_from(choice, "index")), + finish_reason=parse_choice_finish_reason(get_value_from(choice, "finish_reason")), + ) + + +def parse_assistant_message(message: Any, *, images: list[ImagePayload]) -> AssistantMessage: + return AssistantMessage( content=coerce_message_content(get_value_from(message, "content")), reasoning_content=extract_reasoning_content(message), - tool_calls=tool_calls, + tool_calls=extract_tool_calls(get_value_from(message, "tool_calls")), images=images, ) - usage = extract_usage(get_value_from(response, "usage"), generated_images=len(images) if images else None) - usage = fill_reasoning_token_count_from_content(usage, assistant_message.reasoning_content) - return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response) + + +def parse_choice_index(index: Any) -> int | None: + return index if isinstance(index, int) else None + + +def parse_choice_finish_reason(finish_reason: Any) -> str | None: + return finish_reason if isinstance(finish_reason, str) else None # --------------------------------------------------------------------------- diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py index 58db1f944..d28f5f16b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py @@ -61,6 +61,7 @@ class ChatCompletionRequest: model: str messages: list[dict[str, Any]] tools: list[dict[str, Any]] | None = None + n: int | None = None temperature: float | None = None top_p: float | None = None max_tokens: int | None = None @@ -74,11 +75,27 @@ class ChatCompletionRequest: extra_headers: dict[str, str] | None = None +@dataclass +class ChatCompletionChoice: + message: AssistantMessage + index: int | None = None + finish_reason: str | None = None + + @dataclass class ChatCompletionResponse: message: AssistantMessage usage: Usage | None = None raw: Any | None = None + choices: list[ChatCompletionChoice] = field(default_factory=list) + + def __post_init__(self) -> None: + if not self.choices: + self.choices = [ChatCompletionChoice(message=self.message)] + + @property + def messages(self) -> list[AssistantMessage]: + return [choice.message for choice in self.choices] @dataclass diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 073f62ffa..49ebb6698 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -85,6 +85,7 @@ def _build_generation_validation_error(summary: str, exc: ParserException) -> Ge { "temperature", "top_p", + "n", "max_tokens", "stop", "seed", diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py index ed2faf9d4..185117978 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py @@ -115,6 +115,7 @@ def test_completion_posts_to_chat_completions_route() -> None: request = ChatCompletionRequest( model=MODEL, messages=[{"role": "user", "content": "Hi"}], + n=4, temperature=0.7, extra_body={"seed": 42}, extra_headers={"X-Trace": "1"}, @@ -125,6 +126,7 @@ def test_completion_posts_to_chat_completions_route() -> None: assert "/chat/completions" in call_args.args[0] payload = call_args.kwargs["json"] assert payload["model"] == MODEL + assert payload["n"] == 4 assert payload["temperature"] == 0.7 assert payload["seed"] == 42 assert "timeout" not in payload diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py index c9c13de9a..9d614ce19 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py @@ -13,7 +13,9 @@ parse_chat_completion_response, ) from data_designer.engine.models.clients.types import ( + AssistantMessage, ChatCompletionRequest, + ChatCompletionResponse, EmbeddingRequest, ImageGenerationRequest, TransportKwargs, @@ -21,6 +23,44 @@ ) from data_designer.engine.models.usage import TokenCountSource +# --- ChatCompletionResponse compatibility --- + + +def test_chat_completion_response_exposes_choices_for_single_message() -> None: + message = AssistantMessage(content="ok") + response = ChatCompletionResponse(message=message) + + assert response.message is message + assert response.choices[0].message is message + assert response.messages == [message] + + +def test_parse_chat_completion_response_preserves_all_choices() -> None: + response = parse_chat_completion_response( + { + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "first"}, + "finish_reason": "stop", + }, + { + "index": 1, + "message": {"role": "assistant", "content": "second"}, + "finish_reason": "length", + }, + ], + "usage": {"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7}, + } + ) + + assert response.message.content == "first" + assert [choice.message.content for choice in response.choices] == ["first", "second"] + assert [choice.index for choice in response.choices] == [0, 1] + assert [choice.finish_reason for choice in response.choices] == ["stop", "length"] + assert [message.content for message in response.messages] == ["first", "second"] + + # --- TransportKwargs.from_request: extra_body flattening (default) --- @@ -39,6 +79,13 @@ def test_extra_body_keys_are_flattened_into_body() -> None: assert "extra_body" not in transport.body +def test_chat_completion_request_n_is_forwarded_into_body() -> None: + request = ChatCompletionRequest(model="m", messages=[], n=4) + transport = TransportKwargs.from_request(request) + + assert transport.body["n"] == 4 + + def test_extra_body_none_produces_no_extra_keys() -> None: request = ChatCompletionRequest(model="m", messages=[], temperature=0.5) transport = TransportKwargs.from_request(request) diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 89abd74f5..21126fceb 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -11,6 +11,7 @@ from data_designer.engine.mcp.errors import MCPConfigurationError, MCPToolError from data_designer.engine.models.clients.types import ( AssistantMessage, + ChatCompletionRequest, ChatCompletionResponse, EmbeddingResponse, ImageGenerationResponse, @@ -122,6 +123,18 @@ def capture_and_return(*args: Any, **kwargs: Any) -> ChatCompletionResponse: assert captured_messages[0] == expected_messages +@patch.object(ModelFacade, "completion", autospec=True) +def test_generate_forwards_n_to_completion( + mock_completion: Any, + stub_model_facade: ModelFacade, +) -> None: + mock_completion.return_value = _make_response("Hello!") + + stub_model_facade.generate(prompt="does not matter", parser=lambda x: x, n=4) + + assert mock_completion.call_args.kwargs["n"] == 4 + + @patch.object(ModelFacade, "completion", autospec=True) def test_generate_includes_parser_validation_detail_in_user_facing_error( mock_completion: Any, @@ -421,6 +434,21 @@ def test_completion_with_kwargs( assert stub_model_client.completion.call_count == 1 +def test_completion_forwards_n_to_request( + stub_completion_messages: list[ChatMessage], + stub_model_facade: ModelFacade, + stub_model_client: MagicMock, +) -> None: + expected_response = _make_response("Test response") + stub_model_client.completion.return_value = expected_response + + stub_model_facade.completion(stub_completion_messages, n=4) + + request = stub_model_client.completion.call_args.args[0] + assert isinstance(request, ChatCompletionRequest) + assert request.n == 4 + + def test_generate_text_embeddings_success( stub_model_facade: ModelFacade, stub_model_client: MagicMock, From a0e0fc093c4f7ba30b292692147ef072fff57dc9 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 18 May 2026 11:37:36 -0600 Subject: [PATCH 2/6] strip n from generate requests Prevent generate and agenerate from forwarding multi-choice requests that they cannot expose, while keeping completion() multi-choice support intact. Add coverage for async parsing and Anthropic n exclusion. Signed-off-by: Nabin Mulepati --- .../src/data_designer/engine/models/facade.py | 2 ++ .../engine/models/clients/test_anthropic.py | 3 ++- .../engine/models/clients/test_parsing.py | 26 +++++++++++++++++++ .../tests/engine/models/test_facade.py | 17 ++++++++++-- 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 49ebb6698..c46a61c35 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -347,6 +347,7 @@ def generate( while True: completion_kwargs = dict(kwargs) + completion_kwargs.pop("n", None) if tool_schemas is not None: completion_kwargs["tools"] = tool_schemas @@ -452,6 +453,7 @@ async def agenerate( while True: completion_kwargs = dict(kwargs) + completion_kwargs.pop("n", None) if tool_schemas is not None: completion_kwargs["tools"] = tool_schemas diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py index 66f6d85f7..1b1022d15 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py @@ -401,11 +401,12 @@ def test_completion_excludes_openai_specific_params() -> None: frequency_penalty=0.5, presence_penalty=0.5, seed=42, + n=4, ) client.completion(request) payload = sync_mock.post.call_args.kwargs["json"] - for field in ("response_format", "frequency_penalty", "presence_penalty", "seed"): + for field in ("response_format", "frequency_penalty", "presence_penalty", "seed", "n"): assert field not in payload, f"{field!r} should be excluded from Anthropic payload" diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py index 9d614ce19..bc8ce6a46 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py @@ -6,6 +6,7 @@ import pytest from data_designer.engine.models.clients.parsing import ( + aparse_chat_completion_response, extract_reasoning_content, extract_tool_calls, extract_usage, @@ -61,6 +62,31 @@ def test_parse_chat_completion_response_preserves_all_choices() -> None: assert [message.content for message in response.messages] == ["first", "second"] +@pytest.mark.asyncio +async def test_aparse_chat_completion_response_preserves_all_choices() -> None: + response = await aparse_chat_completion_response( + { + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "first"}, + "finish_reason": "stop", + }, + { + "index": 1, + "message": {"role": "assistant", "content": "second"}, + "finish_reason": "length", + }, + ], + } + ) + + assert response.message.content == "first" + assert [choice.message.content for choice in response.choices] == ["first", "second"] + assert [choice.index for choice in response.choices] == [0, 1] + assert [choice.finish_reason for choice in response.choices] == ["stop", "length"] + + # --- TransportKwargs.from_request: extra_body flattening (default) --- diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 21126fceb..21e21cba1 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -124,7 +124,7 @@ def capture_and_return(*args: Any, **kwargs: Any) -> ChatCompletionResponse: @patch.object(ModelFacade, "completion", autospec=True) -def test_generate_forwards_n_to_completion( +def test_generate_drops_n_before_completion( mock_completion: Any, stub_model_facade: ModelFacade, ) -> None: @@ -132,7 +132,20 @@ def test_generate_forwards_n_to_completion( stub_model_facade.generate(prompt="does not matter", parser=lambda x: x, n=4) - assert mock_completion.call_args.kwargs["n"] == 4 + assert "n" not in mock_completion.call_args.kwargs + + +@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_agenerate_drops_n_before_acompletion( + mock_acompletion: AsyncMock, + stub_model_facade: ModelFacade, +) -> None: + mock_acompletion.return_value = _make_response("Hello!") + + await stub_model_facade.agenerate(prompt="does not matter", parser=lambda x: x, n=4) + + assert "n" not in mock_acompletion.call_args.kwargs @patch.object(ModelFacade, "completion", autospec=True) From 100d91980637565251471efe075da607b1d769cb Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 20 May 2026 09:29:22 -0600 Subject: [PATCH 3/6] strip configured n from generate requests Signed-off-by: Nabin Mulepati --- .../src/data_designer/engine/models/facade.py | 39 +++++++++- .../tests/engine/models/test_facade.py | 76 ++++++++++++++++--- 2 files changed, 101 insertions(+), 14 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index c46a61c35..4a98e9596 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -55,6 +55,23 @@ def _identity(x: Any) -> Any: return x +def _drop_multi_choice_request_fields(kwargs: dict[str, Any]) -> dict[str, Any]: + """Remove request controls that would make a single-result API discard choices.""" + sanitized = dict(kwargs) + sanitized.pop("n", None) + + extra_body = sanitized.get("extra_body") + if isinstance(extra_body, dict) and "n" in extra_body: + extra_body = dict(extra_body) + extra_body.pop("n", None) + if extra_body: + sanitized["extra_body"] = extra_body + else: + sanitized.pop("extra_body", None) + + return sanitized + + logger = logging.getLogger(__name__) @@ -199,7 +216,12 @@ def consolidate_kwargs(self, **kwargs: Any) -> dict[str, Any]: # --- completion / acompletion --- def completion( - self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any + self, + messages: list[ChatMessage], + skip_usage_tracking: bool = False, + *, + _allow_multiple_choices: bool = True, + **kwargs: Any, ) -> ChatCompletionResponse: message_payloads = [message.to_dict() for message in messages] logger.debug( @@ -208,6 +230,8 @@ def completion( ) response = None kwargs = self.consolidate_kwargs(**kwargs) + if not _allow_multiple_choices: + kwargs = _drop_multi_choice_request_fields(kwargs) try: request = self._build_chat_completion_request(message_payloads, kwargs) response = self._client.completion(request) @@ -229,7 +253,12 @@ def completion( ) async def acompletion( - self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any + self, + messages: list[ChatMessage], + skip_usage_tracking: bool = False, + *, + _allow_multiple_choices: bool = True, + **kwargs: Any, ) -> ChatCompletionResponse: message_payloads = [message.to_dict() for message in messages] logger.debug( @@ -238,6 +267,8 @@ async def acompletion( ) response = None kwargs = self.consolidate_kwargs(**kwargs) + if not _allow_multiple_choices: + kwargs = _drop_multi_choice_request_fields(kwargs) try: request = self._build_chat_completion_request(message_payloads, kwargs) response = await self._client.acompletion(request) @@ -347,13 +378,13 @@ def generate( while True: completion_kwargs = dict(kwargs) - completion_kwargs.pop("n", None) if tool_schemas is not None: completion_kwargs["tools"] = tool_schemas completion_response = self.completion( messages, skip_usage_tracking=skip_usage_tracking, + _allow_multiple_choices=False, **completion_kwargs, ) @@ -453,13 +484,13 @@ async def agenerate( while True: completion_kwargs = dict(kwargs) - completion_kwargs.pop("n", None) if tool_schemas is not None: completion_kwargs["tools"] = tool_schemas completion_response = await self.acompletion( messages, skip_usage_tracking=skip_usage_tracking, + _allow_multiple_choices=False, **completion_kwargs, ) diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 21e21cba1..4587f0722 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -32,6 +32,15 @@ def _make_response(content: str | None = None, **kwargs: Any) -> ChatCompletionR return make_stub_completion_response(content=content, **kwargs) +def _assert_no_multi_choice_request( + request: Any, + expected_extra_body: dict[str, Any] | None = None, +) -> None: + assert isinstance(request, ChatCompletionRequest) + assert request.n is None + assert request.extra_body == expected_extra_body + + @pytest.fixture def stub_model_facade( stub_model_configs: list[Any], @@ -123,29 +132,76 @@ def capture_and_return(*args: Any, **kwargs: Any) -> ChatCompletionResponse: assert captured_messages[0] == expected_messages -@patch.object(ModelFacade, "completion", autospec=True) -def test_generate_drops_n_before_completion( - mock_completion: Any, +def test_generate_drops_n_from_single_result_request( stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: - mock_completion.return_value = _make_response("Hello!") + stub_model_client.completion.return_value = _make_response("Hello!") stub_model_facade.generate(prompt="does not matter", parser=lambda x: x, n=4) - assert "n" not in mock_completion.call_args.kwargs + _assert_no_multi_choice_request(stub_model_client.completion.call_args.args[0]) + + +def test_generate_drops_extra_body_n_from_single_result_request( + stub_model_facade: ModelFacade, + stub_model_client: MagicMock, +) -> None: + stub_model_client.completion.return_value = _make_response("Hello!") + + stub_model_facade.generate(prompt="does not matter", parser=lambda x: x, extra_body={"n": 4, "seed": 42}) + + _assert_no_multi_choice_request( + stub_model_client.completion.call_args.args[0], + expected_extra_body={"seed": 42}, + ) + + +def test_generate_drops_configured_extra_body_n_from_single_result_request( + stub_model_configs: list[Any], + stub_model_facade: ModelFacade, + stub_model_client: MagicMock, +) -> None: + stub_model_configs[0].inference_parameters.extra_body = {"n": 4, "seed": 42} + stub_model_facade.model_provider.extra_body = {"n": 5, "provider": "kept"} + stub_model_client.completion.return_value = _make_response("Hello!") + + stub_model_facade.generate(prompt="does not matter", parser=lambda x: x) + + _assert_no_multi_choice_request( + stub_model_client.completion.call_args.args[0], + expected_extra_body={"seed": 42, "provider": "kept"}, + ) -@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock) @pytest.mark.asyncio -async def test_agenerate_drops_n_before_acompletion( - mock_acompletion: AsyncMock, +async def test_agenerate_drops_n_from_single_result_request( stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: - mock_acompletion.return_value = _make_response("Hello!") + stub_model_client.acompletion = AsyncMock(return_value=_make_response("Hello!")) await stub_model_facade.agenerate(prompt="does not matter", parser=lambda x: x, n=4) - assert "n" not in mock_acompletion.call_args.kwargs + _assert_no_multi_choice_request(stub_model_client.acompletion.call_args.args[0]) + + +@pytest.mark.asyncio +async def test_agenerate_drops_configured_extra_body_n_from_single_result_request( + stub_model_configs: list[Any], + stub_model_facade: ModelFacade, + stub_model_client: MagicMock, +) -> None: + stub_model_configs[0].inference_parameters.extra_body = {"n": 4, "seed": 42} + stub_model_facade.model_provider.extra_body = {"n": 5, "provider": "kept"} + stub_model_client.acompletion = AsyncMock(return_value=_make_response("Hello!")) + + await stub_model_facade.agenerate(prompt="does not matter", parser=lambda x: x) + + _assert_no_multi_choice_request( + stub_model_client.acompletion.call_args.args[0], + expected_extra_body={"seed": 42, "provider": "kept"}, + ) @patch.object(ModelFacade, "completion", autospec=True) From 6074d8abc2a7232e29cdc302623e900228149c66 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 20 May 2026 09:30:40 -0600 Subject: [PATCH 4/6] rename multiple choice completion flag Signed-off-by: Nabin Mulepati --- .../src/data_designer/engine/models/facade.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 4a98e9596..399f3e85a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -220,7 +220,7 @@ def completion( messages: list[ChatMessage], skip_usage_tracking: bool = False, *, - _allow_multiple_choices: bool = True, + allow_multiple_choices: bool = True, **kwargs: Any, ) -> ChatCompletionResponse: message_payloads = [message.to_dict() for message in messages] @@ -230,7 +230,7 @@ def completion( ) response = None kwargs = self.consolidate_kwargs(**kwargs) - if not _allow_multiple_choices: + if not allow_multiple_choices: kwargs = _drop_multi_choice_request_fields(kwargs) try: request = self._build_chat_completion_request(message_payloads, kwargs) @@ -257,7 +257,7 @@ async def acompletion( messages: list[ChatMessage], skip_usage_tracking: bool = False, *, - _allow_multiple_choices: bool = True, + allow_multiple_choices: bool = True, **kwargs: Any, ) -> ChatCompletionResponse: message_payloads = [message.to_dict() for message in messages] @@ -267,7 +267,7 @@ async def acompletion( ) response = None kwargs = self.consolidate_kwargs(**kwargs) - if not _allow_multiple_choices: + if not allow_multiple_choices: kwargs = _drop_multi_choice_request_fields(kwargs) try: request = self._build_chat_completion_request(message_payloads, kwargs) @@ -378,13 +378,14 @@ def generate( while True: completion_kwargs = dict(kwargs) + completion_kwargs.pop("allow_multiple_choices", None) if tool_schemas is not None: completion_kwargs["tools"] = tool_schemas completion_response = self.completion( messages, skip_usage_tracking=skip_usage_tracking, - _allow_multiple_choices=False, + allow_multiple_choices=False, **completion_kwargs, ) @@ -484,13 +485,14 @@ async def agenerate( while True: completion_kwargs = dict(kwargs) + completion_kwargs.pop("allow_multiple_choices", None) if tool_schemas is not None: completion_kwargs["tools"] = tool_schemas completion_response = await self.acompletion( messages, skip_usage_tracking=skip_usage_tracking, - _allow_multiple_choices=False, + allow_multiple_choices=False, **completion_kwargs, ) From 67403db6029b8d99e41e03bd768bd26dee87d013 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 20 May 2026 09:33:02 -0600 Subject: [PATCH 5/6] move choice sanitizer to private helpers Signed-off-by: Nabin Mulepati --- .../src/data_designer/engine/models/facade.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 399f3e85a..abb3d829a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -55,23 +55,6 @@ def _identity(x: Any) -> Any: return x -def _drop_multi_choice_request_fields(kwargs: dict[str, Any]) -> dict[str, Any]: - """Remove request controls that would make a single-result API discard choices.""" - sanitized = dict(kwargs) - sanitized.pop("n", None) - - extra_body = sanitized.get("extra_body") - if isinstance(extra_body, dict) and "n" in extra_body: - extra_body = dict(extra_body) - extra_body.pop("n", None) - if extra_body: - sanitized["extra_body"] = extra_body - else: - sanitized.pop("extra_body", None) - - return sanitized - - logger = logging.getLogger(__name__) @@ -231,7 +214,7 @@ def completion( response = None kwargs = self.consolidate_kwargs(**kwargs) if not allow_multiple_choices: - kwargs = _drop_multi_choice_request_fields(kwargs) + kwargs = self._drop_multi_choice_request_fields(kwargs) try: request = self._build_chat_completion_request(message_payloads, kwargs) response = self._client.completion(request) @@ -268,7 +251,7 @@ async def acompletion( response = None kwargs = self.consolidate_kwargs(**kwargs) if not allow_multiple_choices: - kwargs = _drop_multi_choice_request_fields(kwargs) + kwargs = self._drop_multi_choice_request_fields(kwargs) try: request = self._build_chat_completion_request(message_payloads, kwargs) response = await self._client.acompletion(request) @@ -766,6 +749,23 @@ def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None: except ValueError as exc: raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc + @staticmethod + def _drop_multi_choice_request_fields(kwargs: dict[str, Any]) -> dict[str, Any]: + """Remove request controls that would make a single-result API discard choices.""" + sanitized = dict(kwargs) + sanitized.pop("n", None) + + extra_body = sanitized.get("extra_body") + if isinstance(extra_body, dict) and "n" in extra_body: + extra_body = dict(extra_body) + extra_body.pop("n", None) + if extra_body: + sanitized["extra_body"] = extra_body + else: + sanitized.pop("extra_body", None) + + return sanitized + def _build_chat_completion_request( self, messages: list[dict[str, Any]], kwargs: dict[str, Any] ) -> ChatCompletionRequest: From 1774e9af4084de09367d5481b94915e1a06d1c7b Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 20 May 2026 09:34:23 -0600 Subject: [PATCH 6/6] order private facade helpers Signed-off-by: Nabin Mulepati --- .../src/data_designer/engine/models/facade.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index abb3d829a..81a935282 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -738,17 +738,6 @@ async def aclose(self) -> None: # --- private helpers --- - def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None: - if tool_alias is None: - return None - if self._mcp_registry is None: - raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.") - - try: - return self._mcp_registry.get_mcp(tool_alias=tool_alias) - except ValueError as exc: - raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc - @staticmethod def _drop_multi_choice_request_fields(kwargs: dict[str, Any]) -> dict[str, Any]: """Remove request controls that would make a single-result API discard choices.""" @@ -766,6 +755,17 @@ def _drop_multi_choice_request_fields(kwargs: dict[str, Any]) -> dict[str, Any]: return sanitized + def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None: + if tool_alias is None: + return None + if self._mcp_registry is None: + raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.") + + try: + return self._mcp_registry.get_mcp(tool_alias=tool_alias) + except ValueError as exc: + raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc + def _build_chat_completion_request( self, messages: list[dict[str, Any]], kwargs: dict[str, Any] ) -> ChatCompletionRequest: