Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions frontend/src/components/app-config/ai-config.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -1253,6 +1253,13 @@ export const AiAssistConfig: React.FC<AiConfigProps> = ({
config,
onSubmit,
}) => {
// Tracked locally rather than derived from the field value so that clearing
// the input (a transient empty value, which commits `null`) does not disable
// the input mid-edit and force the user to re-tick the Override checkbox.
const [maxTokensEnabled, setMaxTokensEnabled] = useState(
config.ai?.max_tokens != null,
);

return (
<SettingGroup>
<SettingSubtitle>AI Assistant</SettingSubtitle>
Expand All @@ -1279,6 +1286,71 @@ export const AiAssistConfig: React.FC<AiConfigProps> = ({
)}
/>

<FormField
control={form.control}
name="ai.max_tokens"
render={({ field }) => {
return (
<div className="flex flex-col gap-y-1">
<div className="flex items-center gap-x-2">
<FormItem className={formItemClasses}>
<FormLabel className="font-normal">
Max output tokens
</FormLabel>
<FormControl>
<Input
data-testid="ai-max-tokens-input"
type="number"
min={1}
disabled={!maxTokensEnabled}
className="w-28 h-6"
value={field.value ?? (maxTokensEnabled ? "" : 32768)}
onChange={(e) => {
const n = Number.parseInt(e.target.value, 10);
field.onChange(Number.isFinite(n) && n > 0 ? n : null);
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
}}
/>
</FormControl>
</FormItem>
<FormItem className={formItemClasses}>
<Checkbox
data-testid="ai-max-tokens-checkbox"
checked={maxTokensEnabled}
onCheckedChange={(checked) => {
const isChecked = checked === true;
setMaxTokensEnabled(isChecked);
// null signals delete to the server; cast because
// UserConfig (OpenAPI-derived) types max_tokens as
// `number | undefined`, but zod accepts `null`.
const next = (
isChecked ? (field.value ?? 32768) : null
) as number | undefined;
// shouldDirty: true forces RHF to keep this in
// dirtyFields even when `next` happens to equal the
// form's defaultValue (e.g. untick → tick when disk
// started with 32768). Otherwise getDirtyValues
// would skip it and the save body would be empty.
form.setValue("ai.max_tokens", next, {
shouldDirty: true,
shouldTouch: true,
});
onSubmit(form.getValues());
}}
Comment thread
kirangadhave marked this conversation as resolved.
/>
<FormLabel className="font-normal">Override</FormLabel>
</FormItem>
</div>

<FormDescription>
Each provider sets its own max output tokens (Anthropic uses a
recommended default). Adjust to control costs or enable more
output.
</FormDescription>
</div>
);
}}
/>

<FormErrorsBanner />
<ModelSelector
label="Chat Model"
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/components/app-config/user-config-form.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -1355,7 +1355,7 @@ export const UserConfigForm: React.FC = () => {
<Form {...form}>
<form
ref={formElement}
onChange={form.handleSubmit(onSubmit)}
onChange={form.handleSubmit((values) => onSubmit(values))}
className="flex text-pretty overflow-hidden"
>
<Tabs
Expand Down
1 change: 1 addition & 0 deletions frontend/src/core/config/config-schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ export const UserConfigSchema = z
ai: z
.looseObject({
rules: z.string().prefault(""),
max_tokens: z.number().int().positive().nullable().optional(),
mode: z.enum(COPILOT_MODES).prefault("manual"),
inline_tooltip: z.boolean().prefault(false),
open_ai: AiConfigSchema.optional(),
Expand Down
3 changes: 2 additions & 1 deletion marimo/_config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,8 @@ def merge_default_config(


def merge_config(
config: MarimoConfig, new_config: PartialMarimoConfig | MarimoConfig
config: MarimoConfig,
new_config: PartialMarimoConfig | MarimoConfig,
) -> MarimoConfig:
"""Merge a user configuration with a new configuration. The new config
will take precedence over the default config.
Expand Down
15 changes: 15 additions & 0 deletions marimo/_config/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,11 @@ def save_config(
# Merge the current config with the new config
current_config = self._load_config()
merged = merge_config(current_config, config)
# None-as-delete: any key whose merged value is None (typically because
# the incoming config explicitly sent null) is removed from disk. Lets
# the UI clear optional scalars (e.g. ai.max_tokens) without a separate
# delete primitive.
_drop_none_values(cast(dict[str, Any], merged))
Comment thread
kirangadhave marked this conversation as resolved.
Comment on lines +403 to +407
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should note this for the bug bash. We had issues in the past where config changes on the frontend didn't save; this whole function is a bit fragile.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add it to the notes for the next release.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we need more tests around this?

Copy link
Copy Markdown
Collaborator

@Light2Dark Light2Dark May 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe..

The previous issue was that if you added a custom model and then refreshed, it wouldn't be saved.
The same happened if you ticked certain models in the AI models dropdown and then refreshed: the selection would disappear.
The same happened when deleting a custom model.

I would test the cases above during the bug bash.


with open(config_path, "w", encoding="utf-8") as f:
tomlkit.dump(merged, f, sort_keys=True)
Expand Down Expand Up @@ -459,3 +464,13 @@ def get_config(self, *, hide_secrets: bool = True) -> PartialMarimoConfig:
if hide_secrets:
return mask_secrets_partial(self.override_config)
return self.override_config


def _drop_none_values(d: dict[str, Any]) -> None:
    """Recursively remove keys whose value is None, in place.

    Nested dicts are pruned as well, including dicts that sit inside lists
    (e.g. arrays of tables); TOML has no null type, so None-valued keys
    should not survive anywhere in the tree.

    List elements themselves are never removed: only None-valued *keys* of
    dicts are dropped, so list lengths and ordering are preserved.

    Args:
        d: The mapping to prune. It is mutated in place.

    Returns:
        None. The result is the mutated ``d``.
    """
    # Snapshot the keys first: entries are deleted while walking the dict.
    for key in list(d):
        v = d[key]
        if v is None:
            del d[key]
        elif isinstance(v, dict):
            _drop_none_values(v)
        elif isinstance(v, list):
            _drop_none_in_list(v)


def _drop_none_in_list(items: list[Any]) -> None:
    """Prune None-valued keys from every dict nested, at any depth, in a list."""
    for item in items:
        if isinstance(item, dict):
            _drop_none_values(item)
        elif isinstance(item, list):
            _drop_none_in_list(item)
10 changes: 4 additions & 6 deletions marimo/_server/ai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
MarimoConfig,
PartialMarimoConfig,
)
from marimo._server.ai.constants import DEFAULT_MAX_TOKENS, DEFAULT_MODEL
from marimo._server.ai.constants import DEFAULT_MODEL
from marimo._server.ai.ids import AiModelId
from marimo._server.ai.tools.tool_manager import get_tool_manager
from marimo._server.ai.tools.types import ToolDefinition
Expand Down Expand Up @@ -346,11 +346,9 @@ def get_autocomplete_model(
)


def get_max_tokens(config: MarimoConfig) -> int:
if "ai" not in config:
return DEFAULT_MAX_TOKENS
if "max_tokens" not in config["ai"]:
return DEFAULT_MAX_TOKENS
def get_max_tokens(config: MarimoConfig) -> int | None:
if "ai" not in config or "max_tokens" not in config["ai"]:
return None
return config["ai"]["max_tokens"]


Expand Down
2 changes: 1 addition & 1 deletion marimo/_server/ai/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2026 Marimo. All rights reserved.
from __future__ import annotations

DEFAULT_MAX_TOKENS = 4096
ANTHROPIC_DEFAULT_MAX_TOKENS = 32768
Comment thread
Light2Dark marked this conversation as resolved.
DEFAULT_MODEL = "openai/gpt-4o"
53 changes: 35 additions & 18 deletions marimo/_server/ai/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
require_vercel_ai_sdk_support,
)
from marimo._server.ai.config import AnyProviderConfig
from marimo._server.ai.constants import ANTHROPIC_DEFAULT_MAX_TOKENS
from marimo._server.ai.ids import AiModelId
from marimo._server.ai.tools.tool_manager import get_tool_manager
from marimo._server.ai.tools.types import ToolDefinition
Expand Down Expand Up @@ -114,12 +115,12 @@ def create_provider(self, config: AnyProviderConfig) -> ProviderT:
"""Create a provider for the given config."""

@abstractmethod
def create_model(self, max_tokens: int) -> Model:
def create_model(self, max_tokens: int | None) -> Model:
"""Create a Pydantic AI model for the given max tokens."""

def create_agent(
self,
max_tokens: int,
max_tokens: int | None,
tools: list[ToolDefinition],
system_prompt: str,
) -> Agent[None, DeferredToolRequests | str]:
Expand Down Expand Up @@ -165,7 +166,7 @@ async def stream_completion(
self,
messages: list[ServerUIMessage],
system_prompt: str,
max_tokens: int,
max_tokens: int | None,
additional_tools: list[ToolDefinition],
stream_options: StreamOptions | None = None,
) -> StreamingResponse:
Expand Down Expand Up @@ -201,7 +202,7 @@ async def stream_text(
user_prompt: str,
messages: list[ServerUIMessage],
system_prompt: str,
max_tokens: int,
max_tokens: int | None,
additional_tools: list[ToolDefinition],
) -> AsyncGenerator[str]:
"""Return a stream of text from the given messages."""
Expand Down Expand Up @@ -294,13 +295,16 @@ def create_provider(self, config: AnyProviderConfig) -> PydanticGoogle:
provider = PydanticGoogle()
return provider

def create_model(self, max_tokens: int) -> GoogleModel:
def create_model(self, max_tokens: int | None) -> GoogleModel:
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings

settings: GoogleModelSettings = (
{"max_tokens": max_tokens} if max_tokens is not None else {}
)
return GoogleModel(
model_name=self.model,
provider=self.provider,
settings=GoogleModelSettings(max_tokens=max_tokens),
settings=settings,
)


Expand Down Expand Up @@ -399,16 +403,18 @@ def create_provider(self, config: AnyProviderConfig) -> PydanticOpenAI:
client = self.get_openai_client(config)
return PydanticOpenAI(openai_client=client)

def create_model(self, max_tokens: int) -> OpenAIResponsesModel:
def create_model(self, max_tokens: int | None) -> OpenAIResponsesModel:
from pydantic_ai.models.openai import (
OpenAIResponsesModel,
OpenAIResponsesModelSettings,
)

settings: OpenAIResponsesModelSettings = (
{"max_tokens": max_tokens} if max_tokens is not None else {}
)
return OpenAIResponsesModel(
model_name=self.model,
provider=self.provider,
settings=OpenAIResponsesModelSettings(max_tokens=max_tokens),
settings=settings,
)

def _build_agent_settings(self, model: Model) -> ModelSettings | None:
Expand Down Expand Up @@ -645,30 +651,32 @@ def _create_custom_provider(
client = self.get_openai_client(config)
return PydanticOpenAI(openai_client=client)

def create_model(self, max_tokens: int) -> OpenAIChatModel:
def create_model(self, max_tokens: int | None) -> OpenAIChatModel:
"""Default to OpenAIChatModel"""

from pydantic_ai.models.openai import (
OpenAIChatModel,
OpenAIChatModelSettings,
)

settings: OpenAIChatModelSettings = (
{"max_tokens": max_tokens} if max_tokens is not None else {}
)
return OpenAIChatModel(
model_name=self.model,
provider=self.provider,
settings=OpenAIChatModelSettings(max_tokens=max_tokens),
settings=settings,
)

def create_agent(
self,
max_tokens: int,
max_tokens: int | None,
tools: list[ToolDefinition],
system_prompt: str,
) -> Agent[None, DeferredToolRequests | str]:
"""Create a Pydantic AI agent"""
from pydantic_ai import Agent, UserError
from pydantic_ai.models import infer_model
from pydantic_ai.settings import ModelSettings

try:
model = infer_model(
Expand All @@ -687,7 +695,9 @@ def create_agent(
)
model = self.create_model(max_tokens)

agent_settings = ModelSettings(max_tokens=max_tokens)
agent_settings: ModelSettings = (
{"max_tokens": max_tokens} if max_tokens is not None else {}
)
agent_settings.update(self._build_agent_settings(model) or {})

toolset, output_type = self._get_toolsets_and_output_type(tools)
Expand Down Expand Up @@ -724,7 +734,7 @@ def create_provider(self, config: AnyProviderConfig) -> PydanticAnthropic:

return PydanticAnthropic(api_key=config.api_key)

def create_model(self, max_tokens: int) -> Model:
def create_model(self, max_tokens: int | None) -> Model:
from pydantic_ai.models.anthropic import (
AnthropicModel,
AnthropicModelSettings,
Expand All @@ -734,7 +744,11 @@ def create_model(self, max_tokens: int) -> Model:
anthropic_model_profile,
)

settings: AnthropicModelSettings = {"max_tokens": max_tokens}
settings: AnthropicModelSettings = {
"max_tokens": max_tokens
if max_tokens is not None
else ANTHROPIC_DEFAULT_MAX_TOKENS
}

# Anthropic extended thinking requires temperature=1; non-thinking
# models keep our default coding temperature. Some adaptive-only
Expand Down Expand Up @@ -812,16 +826,19 @@ def create_provider(self, config: AnyProviderConfig) -> PydanticBedrock:
# For bedrock, the config sets the region name as the base_url
return PydanticBedrock(region_name=config.base_url)

def create_model(self, max_tokens: int) -> BedrockConverseModel:
def create_model(self, max_tokens: int | None) -> BedrockConverseModel:
from pydantic_ai.models.bedrock import (
BedrockConverseModel,
BedrockModelSettings,
)

settings: BedrockModelSettings = (
{"max_tokens": max_tokens} if max_tokens is not None else {}
)
return BedrockConverseModel(
model_name=self.model,
provider=self.provider,
settings=BedrockModelSettings(max_tokens=max_tokens),
settings=settings,
)


Expand Down
3 changes: 2 additions & 1 deletion marimo/_server/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ class SaveAppConfigurationRequest(msgspec.Struct, rename="camel"):


class SaveUserConfigurationRequest(msgspec.Struct, rename="camel"):
# deep partial user configuration
# deep partial user configuration; keys with value `None` are removed
# from the on-disk merged config (None-as-delete)
config: dict[str, Any]


Expand Down
Loading
Loading