diff --git a/frontend/src/components/app-config/ai-config.tsx b/frontend/src/components/app-config/ai-config.tsx index 0e22a4b46fb..839d3b07b3e 100644 --- a/frontend/src/components/app-config/ai-config.tsx +++ b/frontend/src/components/app-config/ai-config.tsx @@ -1253,6 +1253,13 @@ export const AiAssistConfig: React.FC = ({ config, onSubmit, }) => { + // Tracked locally rather than derived from the field value so that clearing + // the input (a transient empty value, which commits `null`) does not disable + // the input mid-edit and force the user to re-tick the Override checkbox. + const [maxTokensEnabled, setMaxTokensEnabled] = useState( + config.ai?.max_tokens != null, + ); + return ( AI Assistant @@ -1279,6 +1286,71 @@ export const AiAssistConfig: React.FC = ({ )} /> + { + return ( +
+
+ + + Max output tokens + + + { + const n = Number.parseInt(e.target.value, 10); + field.onChange(Number.isFinite(n) && n > 0 ? n : null); + }} + /> + + + + { + const isChecked = checked === true; + setMaxTokensEnabled(isChecked); + // null signals delete to the server; cast because + // UserConfig (OpenAPI-derived) types max_tokens as + // `number | undefined`, but zod accepts `null`. + const next = ( + isChecked ? (field.value ?? 32768) : null + ) as number | undefined; + // shouldDirty: true forces RHF to keep this in + // dirtyFields even when `next` happens to equal the + // form's defaultValue (e.g. untick → tick when disk + // started with 32768). Otherwise getDirtyValues + // would skip it and the save body would be empty. + form.setValue("ai.max_tokens", next, { + shouldDirty: true, + shouldTouch: true, + }); + onSubmit(form.getValues()); + }} + /> + Override + +
+ + + Each provider sets its own max output tokens (Anthropic uses a + recommended default). Adjust to control costs or enable more + output. + +
+ ); + }} + /> + {
onSubmit(values))} className="flex text-pretty overflow-hidden" > MarimoConfig: """Merge a user configuration with a new configuration. The new config will take precedence over the default config. diff --git a/marimo/_config/manager.py b/marimo/_config/manager.py index a36d54d31bd..63755ce3570 100644 --- a/marimo/_config/manager.py +++ b/marimo/_config/manager.py @@ -400,6 +400,11 @@ def save_config( # Merge the current config with the new config current_config = self._load_config() merged = merge_config(current_config, config) + # None-as-delete: any key whose merged value is None (typically because + # the incoming config explicitly sent null) is removed from disk. Lets + # the UI clear optional scalars (e.g. ai.max_tokens) without a separate + # delete primitive. + _drop_none_values(cast(dict[str, Any], merged)) with open(config_path, "w", encoding="utf-8") as f: tomlkit.dump(merged, f, sort_keys=True) @@ -459,3 +464,13 @@ def get_config(self, *, hide_secrets: bool = True) -> PartialMarimoConfig: if hide_secrets: return mask_secrets_partial(self.override_config) return self.override_config + + +def _drop_none_values(d: dict[str, Any]) -> None: + """Recursively remove keys whose value is None, in place.""" + for key in list(d): + v = d[key] + if v is None: + del d[key] + elif isinstance(v, dict): + _drop_none_values(v) diff --git a/marimo/_server/ai/config.py b/marimo/_server/ai/config.py index 9c45f5fc3c4..4c59251b71c 100644 --- a/marimo/_server/ai/config.py +++ b/marimo/_server/ai/config.py @@ -15,7 +15,7 @@ MarimoConfig, PartialMarimoConfig, ) -from marimo._server.ai.constants import DEFAULT_MAX_TOKENS, DEFAULT_MODEL +from marimo._server.ai.constants import DEFAULT_MODEL from marimo._server.ai.ids import AiModelId from marimo._server.ai.tools.tool_manager import get_tool_manager from marimo._server.ai.tools.types import ToolDefinition @@ -346,11 +346,9 @@ def get_autocomplete_model( ) -def get_max_tokens(config: MarimoConfig) -> int: - if "ai" not in config: - return DEFAULT_MAX_TOKENS - if "max_tokens" not in config["ai"]: - return DEFAULT_MAX_TOKENS +def get_max_tokens(config: MarimoConfig) -> int | None: + if "ai" not in config or "max_tokens" not in config["ai"]: + return None return config["ai"]["max_tokens"] diff --git a/marimo/_server/ai/constants.py b/marimo/_server/ai/constants.py index e0540257854..9bd8cc8fed2 100644 --- a/marimo/_server/ai/constants.py +++ b/marimo/_server/ai/constants.py @@ -1,5 +1,5 @@ # Copyright 2026 Marimo. All rights reserved. from __future__ import annotations -DEFAULT_MAX_TOKENS = 4096 +ANTHROPIC_DEFAULT_MAX_TOKENS = 32768 DEFAULT_MODEL = "openai/gpt-4o" diff --git a/marimo/_server/ai/providers.py b/marimo/_server/ai/providers.py index 0c896a46189..cf3be6827ab 100644 --- a/marimo/_server/ai/providers.py +++ b/marimo/_server/ai/providers.py @@ -29,6 +29,7 @@ require_vercel_ai_sdk_support, ) from marimo._server.ai.config import AnyProviderConfig +from marimo._server.ai.constants import ANTHROPIC_DEFAULT_MAX_TOKENS from marimo._server.ai.ids import AiModelId from marimo._server.ai.tools.tool_manager import get_tool_manager from marimo._server.ai.tools.types import ToolDefinition @@ -114,12 +115,12 @@ def create_provider(self, config: AnyProviderConfig) -> ProviderT: """Create a provider for the given config.""" @abstractmethod - def create_model(self, max_tokens: int) -> Model: + def create_model(self, max_tokens: int | None) -> Model: """Create a Pydantic AI model for the given max tokens.""" def create_agent( self, - max_tokens: int, + max_tokens: int | None, tools: list[ToolDefinition], system_prompt: str, ) -> Agent[None, DeferredToolRequests | str]: @@ -165,7 +166,7 @@ async def stream_completion( self, messages: list[ServerUIMessage], system_prompt: str, - max_tokens: int, + max_tokens: int | None, additional_tools: list[ToolDefinition], stream_options: StreamOptions | None = None, ) -> StreamingResponse: @@ -201,7 +202,7 @@ async def stream_text( user_prompt: str, messages: list[ServerUIMessage], system_prompt: str, - max_tokens: int, + max_tokens: int | None, additional_tools: list[ToolDefinition], ) -> AsyncGenerator[str]: """Return a stream of text from the given messages.""" @@ -294,13 +295,16 @@ def create_provider(self, config: AnyProviderConfig) -> PydanticGoogle: provider = PydanticGoogle() return provider - def create_model(self, max_tokens: int) -> GoogleModel: + def create_model(self, max_tokens: int | None) -> GoogleModel: from pydantic_ai.models.google import GoogleModel, GoogleModelSettings + settings: GoogleModelSettings = ( + {"max_tokens": max_tokens} if max_tokens is not None else {} + ) return GoogleModel( model_name=self.model, provider=self.provider, - settings=GoogleModelSettings(max_tokens=max_tokens), + settings=settings, ) @@ -399,16 +403,18 @@ def create_provider(self, config: AnyProviderConfig) -> PydanticOpenAI: client = self.get_openai_client(config) return PydanticOpenAI(openai_client=client) - def create_model(self, max_tokens: int) -> OpenAIResponsesModel: + def create_model(self, max_tokens: int | None) -> OpenAIResponsesModel: from pydantic_ai.models.openai import ( OpenAIResponsesModel, - OpenAIResponsesModelSettings, ) + settings: OpenAIResponsesModelSettings = ( + {"max_tokens": max_tokens} if max_tokens is not None else {} + ) return OpenAIResponsesModel( model_name=self.model, provider=self.provider, - settings=OpenAIResponsesModelSettings(max_tokens=max_tokens), + settings=settings, ) def _build_agent_settings(self, model: Model) -> ModelSettings | None: @@ -645,7 +651,7 @@ def _create_custom_provider( client = self.get_openai_client(config) return PydanticOpenAI(openai_client=client) - def create_model(self, max_tokens: int) -> OpenAIChatModel: + def create_model(self, max_tokens: int | None) -> OpenAIChatModel: """Default to OpenAIChatModel""" from pydantic_ai.models.openai import ( @@ -653,22 +659,24 @@ def create_model(self, max_tokens: int) -> OpenAIChatModel: OpenAIChatModelSettings, ) + settings: OpenAIChatModelSettings = ( + {"max_tokens": max_tokens} if max_tokens is not None else {} + ) return OpenAIChatModel( model_name=self.model, provider=self.provider, - settings=OpenAIChatModelSettings(max_tokens=max_tokens), + settings=settings, ) def create_agent( self, - max_tokens: int, + max_tokens: int | None, tools: list[ToolDefinition], system_prompt: str, ) -> Agent[None, DeferredToolRequests | str]: """Create a Pydantic AI agent""" from pydantic_ai import Agent, UserError from pydantic_ai.models import infer_model - from pydantic_ai.settings import ModelSettings try: model = infer_model( @@ -687,7 +695,9 @@ def create_agent( ) model = self.create_model(max_tokens) - agent_settings = ModelSettings(max_tokens=max_tokens) + agent_settings: ModelSettings = ( + {"max_tokens": max_tokens} if max_tokens is not None else {} + ) agent_settings.update(self._build_agent_settings(model) or {}) toolset, output_type = self._get_toolsets_and_output_type(tools) @@ -724,7 +734,7 @@ def create_provider(self, config: AnyProviderConfig) -> PydanticAnthropic: return PydanticAnthropic(api_key=config.api_key) - def create_model(self, max_tokens: int) -> Model: + def create_model(self, max_tokens: int | None) -> Model: from pydantic_ai.models.anthropic import ( AnthropicModel, AnthropicModelSettings, @@ -734,7 +744,11 @@ def create_model(self, max_tokens: int) -> Model: anthropic_model_profile, ) - settings: AnthropicModelSettings = {"max_tokens": max_tokens} + settings: AnthropicModelSettings = { + "max_tokens": max_tokens + if max_tokens is not None + else ANTHROPIC_DEFAULT_MAX_TOKENS + } # Anthropic extended thinking requires temperature=1; non-thinking # models keep our default coding temperature. Some adaptive-only @@ -812,16 +826,19 @@ def create_provider(self, config: AnyProviderConfig) -> PydanticBedrock: # For bedrock, the config sets the region name as the base_url return PydanticBedrock(region_name=config.base_url) - def create_model(self, max_tokens: int) -> BedrockConverseModel: + def create_model(self, max_tokens: int | None) -> BedrockConverseModel: from pydantic_ai.models.bedrock import ( BedrockConverseModel, BedrockModelSettings, ) + settings: BedrockModelSettings = ( + {"max_tokens": max_tokens} if max_tokens is not None else {} + ) return BedrockConverseModel( model_name=self.model, provider=self.provider, - settings=BedrockModelSettings(max_tokens=max_tokens), + settings=settings, ) diff --git a/marimo/_server/models/models.py b/marimo/_server/models/models.py index c21a652adac..8885e0f3eeb 100644 --- a/marimo/_server/models/models.py +++ b/marimo/_server/models/models.py @@ -320,7 +320,8 @@ class SaveAppConfigurationRequest(msgspec.Struct, rename="camel"): class SaveUserConfigurationRequest(msgspec.Struct, rename="camel"): - # deep partial user configuration + # deep partial user configuration; keys with value `None` are removed + # from the on-disk merged config (None-as-delete) config: dict[str, Any] diff --git a/tests/_config/test_manager.py b/tests/_config/test_manager.py index 222535981f4..7c57c0e6dee 100644 --- a/tests/_config/test_manager.py +++ b/tests/_config/test_manager.py @@ -3,7 +3,7 @@ import textwrap from collections.abc import Callable from functools import wraps -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, cast from unittest.mock import patch import pytest @@ -164,6 +164,65 @@ def test_save_config_is_deterministic(tmp_path: Path) -> None: assert bytes_a == bytes_b +@restore_config +@patch("tomlkit.dump") +def test_save_config_none_deletes_key(mock_dump: Any) -> None: + """None-as-delete: sending {ai: {max_tokens: None}} removes the key.""" + mock_config = merge_default_config( + PartialMarimoConfig(ai={"max_tokens": 8192, "rules": "be terse"}) + ) + manager = UserConfigManager() + manager._load_config = lambda: mock_config + + manager.save_config( + cast( + PartialMarimoConfig, + {"ai": {"max_tokens": None}}, + ) + ) + + written = mock_dump.mock_calls[0][1][0] + assert "max_tokens" not in written["ai"] + # sibling key untouched + assert written["ai"]["rules"] == "be terse" + + +@restore_config +def test_save_config_with_none_does_not_raise(tmp_path: Path) -> None: + """Regression guard: TOML has no null type, so a None value reaching + tomlkit.dump raises ConvertError. _drop_none_values must strip it first, + making the real (unmocked) save succeed and omit the key.""" + config_path = tmp_path / "marimo.toml" + mock_config = merge_default_config( + PartialMarimoConfig(ai={"max_tokens": 8192, "rules": "be terse"}) + ) + manager = UserConfigManager() + manager._load_config = lambda: mock_config + + with patch.object( + manager, "get_config_path", return_value=str(config_path) + ): + manager.save_config( + cast(PartialMarimoConfig, {"ai": {"max_tokens": None}}) + ) + + contents = config_path.read_text() + assert "max_tokens" not in contents + assert "be terse" in contents + + +def test_drop_none_values_strips_nested_none() -> None: + from marimo._config.manager import _drop_none_values + + d: dict[str, Any] = { + "keep": 1, + "drop": None, + "nested": {"keep": "x", "drop": None}, + } + _drop_none_values(d) + assert d == {"keep": 1, "nested": {"keep": "x"}} + + @restore_config @patch("tomlkit.dump") def test_can_save_secrets(mock_dump: Any) -> None: diff --git a/tests/_server/ai/test_ai_config.py b/tests/_server/ai/test_ai_config.py index e9a3a1b71a5..a021fe5e590 100644 --- a/tests/_server/ai/test_ai_config.py +++ b/tests/_server/ai/test_ai_config.py @@ -24,7 +24,7 @@ get_edit_model, get_max_tokens, ) -from marimo._server.ai.constants import DEFAULT_MAX_TOKENS, DEFAULT_MODEL +from marimo._server.ai.constants import DEFAULT_MODEL from marimo._server.ai.tools.types import ToolDefinition from marimo._utils.http import HTTPStatus @@ -1081,7 +1081,7 @@ def test_get_max_tokens_from_config(self): assert result == 2048 def test_get_max_tokens_default_no_ai_config(self): - """Test getting default max tokens when no AI config.""" + """Test getting max tokens returns None when no AI config.""" config = cast( MarimoConfig, { @@ -1091,10 +1091,10 @@ def test_get_max_tokens_default_no_ai_config(self): result = get_max_tokens(config) - assert result == DEFAULT_MAX_TOKENS + assert result is None def test_get_max_tokens_default_no_max_tokens(self): - """Test getting default max tokens when max_tokens not specified.""" + """Test getting max tokens returns None when max_tokens not specified.""" config = cast( MarimoConfig, { @@ -1104,7 +1104,7 @@ def test_get_max_tokens_default_no_max_tokens(self): result = get_max_tokens(config) - assert result == DEFAULT_MAX_TOKENS + assert result is None def test_get_autocomplete_model(self) -> None: """Test get_autocomplete_model with new ai.models.autocomplete_model config.""" diff --git a/tests/_server/ai/test_providers.py b/tests/_server/ai/test_providers.py index 337dbef9fba..20722d8383e 100644 --- a/tests/_server/ai/test_providers.py +++ b/tests/_server/ai/test_providers.py @@ -452,3 +452,79 @@ async def test_completion_does_not_pass_redundant_instructions() -> None: # This asserts the duplication is gone assert instructions == "Test prompt" + + +@pytest.mark.skipif( + not DependencyManager.anthropic.has() + or not DependencyManager.pydantic_ai.has(), + reason="anthropic or pydantic_ai not installed", +) +def test_anthropic_applies_default_floor_when_max_tokens_none() -> None: + """When no max_tokens is configured, Anthropic still receives 32768.""" + from marimo._server.ai.constants import ANTHROPIC_DEFAULT_MAX_TOKENS + + config = AnyProviderConfig(api_key="test-key", base_url=None) + provider = AnthropicProvider("claude-sonnet-4-5", config) + model = provider.create_model(max_tokens=None) + assert ( + dict(model.settings).get("max_tokens") == ANTHROPIC_DEFAULT_MAX_TOKENS + ) + + +@pytest.mark.skipif( + not DependencyManager.anthropic.has() + or not DependencyManager.pydantic_ai.has(), + reason="anthropic or pydantic_ai not installed", +) +def test_anthropic_override_wins_over_default_floor() -> None: + """An explicit max_tokens overrides the Anthropic default floor.""" + config = AnyProviderConfig(api_key="test-key", base_url=None) + provider = AnthropicProvider("claude-sonnet-4-5", config) + model = provider.create_model(max_tokens=12345) + assert dict(model.settings).get("max_tokens") == 12345 + + +@pytest.mark.requires("pydantic_ai") +def test_openai_chat_omits_max_tokens_when_none() -> None: + """Non-Anthropic providers omit max_tokens entirely when not set, so + pydantic-ai falls through to the upstream provider's default.""" + config = AnyProviderConfig(api_key="test-key", base_url="http://test-url") + provider = OpenAIProvider("gpt-4", config) + model = provider.create_model(max_tokens=None) + assert "max_tokens" not in dict(model.settings) + + +@pytest.mark.requires("pydantic_ai") +def test_openai_chat_passes_explicit_max_tokens() -> None: + """Non-Anthropic providers pass through an explicit max_tokens.""" + config = AnyProviderConfig(api_key="test-key", base_url="http://test-url") + provider = OpenAIProvider("gpt-4", config) + model = provider.create_model(max_tokens=12345) + assert dict(model.settings).get("max_tokens") == 12345 + + +@pytest.mark.requires("pydantic_ai") +def test_custom_provider_agent_passes_explicit_max_tokens() -> None: + """The chat path builds the agent (not the model), so the agent's + model_settings must carry the explicit max_tokens.""" + config = AnyProviderConfig(api_key="test-key", base_url="http://test-url") + provider = get_completion_provider(config, "openrouter/gpt-4") + with patch("marimo._server.ai.providers.get_tool_manager") as mock_get_tm: + mock_get_tm.return_value = MagicMock() + agent = provider.create_agent( + max_tokens=12345, tools=[], system_prompt="x" + ) + assert dict(agent.model_settings or {}).get("max_tokens") == 12345 + + +@pytest.mark.requires("pydantic_ai") +def test_custom_provider_agent_omits_max_tokens_when_none() -> None: + """The chat path omits max_tokens from agent model_settings when unset.""" + config = AnyProviderConfig(api_key="test-key", base_url="http://test-url") + provider = get_completion_provider(config, "openrouter/gpt-4") + with patch("marimo._server.ai.providers.get_tool_manager") as mock_get_tm: + mock_get_tm.return_value = MagicMock() + agent = provider.create_agent( + max_tokens=None, tools=[], system_prompt="x" + ) + assert "max_tokens" not in dict(agent.model_settings or {})