66from pydantic import BaseModel , ValidationError
77from typing_extensions import Unpack
88
9- from workflowai .core ._common_types import BaseRunParams , OutputValidator , VersionRunParams
9+ from workflowai .core ._common_types import (
10+ BaseRunParams ,
11+ OtherRunParams ,
12+ OutputValidator ,
13+ VersionRunParams ,
14+ )
1015from workflowai .core .client ._api import APIClient
1116from workflowai .core .client ._models import (
1217 CompletionsResponse ,
2732 global_default_version_reference ,
2833)
2934from workflowai .core .domain .completion import Completion
30- from workflowai .core .domain .errors import BaseError , WorkflowAIError
35+ from workflowai .core .domain .errors import BaseError , MaxTurnsReachedError , WorkflowAIError
3136from workflowai .core .domain .run import Run
3237from workflowai .core .domain .task import AgentInput , AgentOutput
3338from workflowai .core .domain .tool import Tool
@@ -83,7 +88,7 @@ class MyOutput(BaseModel):
8388 ```
8489 """
8590
86- _DEFAULT_MAX_ITERATIONS = 10
91+ _DEFAULT_MAX_TURNS = 10
8792
8893 def __init__ (
8994 self ,
@@ -94,6 +99,7 @@ def __init__(
9499 schema_id : Optional [int ] = None ,
95100 version : Optional [VersionReference ] = None ,
96101 tools : Optional [Iterable [Callable [..., Any ]]] = None ,
102+ ** kwargs : Unpack [OtherRunParams ],
97103 ):
98104 self .agent_id = agent_id
99105 self .schema_id = schema_id
@@ -104,6 +110,7 @@ def __init__(
104110 self ._tools = self .build_tools (tools ) if tools else None
105111
106112 self ._default_validator = default_validator (output_cls )
113+ self ._other_run_params = kwargs
107114
108115 @classmethod
109116 def build_tools (cls , tools : Iterable [Callable [..., Any ]]):
@@ -180,6 +187,13 @@ def _sanitize_version(self, params: VersionRunParams) -> Union[str, int, dict[st
180187 dumped ["temperature" ] = combined .temperature
181188 return dumped
182189
190+ def _get_run_param (self , key : str , params : OtherRunParams , default : Any = None ) -> Any :
191+ if key in params :
192+ return params [key ] # pyright: ignore [reportUnknownVariableType]
193+ if key in self ._other_run_params :
194+ return self ._other_run_params [key ] # pyright: ignore [reportUnknownVariableType]
195+ return default
196+
183197 async def _prepare_run (self , agent_input : AgentInput , stream : bool , ** kwargs : Unpack [RunParams [AgentOutput ]]):
184198 schema_id = self .schema_id
185199 if not schema_id :
@@ -192,15 +206,14 @@ async def _prepare_run(self, agent_input: AgentInput, stream: bool, **kwargs: Un
192206 task_input = agent_input .model_dump (by_alias = True ),
193207 version = version ,
194208 stream = stream ,
195- use_cache = kwargs . get ("use_cache" ),
209+ use_cache = self . _get_run_param ("use_cache" , kwargs ),
196210 metadata = kwargs .get ("metadata" ),
197- labels = kwargs .get ("labels" ),
198211 )
199212
200213 route = f"/v1/_/agents/{ self .agent_id } /schemas/{ self .schema_id } /run"
201214 should_retry , wait_for_exception = build_retryable_wait (
202- kwargs . get ("max_retry_delay" , 60 ),
203- kwargs . get ("max_retry_count" , 1 ),
215+ self . _get_run_param ("max_retry_delay" , kwargs , 60 ),
216+ self . _get_run_param ("max_retry_count" , kwargs , 1 ),
204217 )
205218 return self ._PreparedRun (request , route , should_retry , wait_for_exception , schema_id )
206219
@@ -227,8 +240,8 @@ async def _prepare_reply(
227240 )
228241 route = f"/v1/_/agents/{ self .agent_id } /runs/{ run_id } /reply"
229242 should_retry , wait_for_exception = build_retryable_wait (
230- kwargs . get ("max_retry_delay" , 60 ),
231- kwargs . get ("max_retry_count" , 1 ),
243+ self . _get_run_param ("max_retry_delay" , kwargs , 60 ),
244+ self . _get_run_param ("max_retry_count" , kwargs , 1 ),
232245 )
233246
234247 return self ._PreparedRun (request , route , should_retry , wait_for_exception , self .schema_id )
@@ -324,8 +337,14 @@ async def _build_run(
324337 run = self ._build_run_no_tools (chunk , schema_id , validator )
325338
326339 if run .tool_call_requests :
327- if current_iteration >= kwargs .get ("max_iterations" , self ._DEFAULT_MAX_ITERATIONS ):
328- raise WorkflowAIError (error = BaseError (message = "max tool iterations reached" ), response = None )
340+ if current_iteration >= self ._get_run_param ("max_turns" , kwargs , self ._DEFAULT_MAX_TURNS ):
341+ if self ._get_run_param ("max_turns_raises" , kwargs , default = True ):
342+ raise MaxTurnsReachedError (
343+ error = BaseError (message = "max tool iterations reached" ),
344+ response = None ,
345+ tool_call_requests = run .tool_call_requests ,
346+ )
347+ return run
329348 with_reply = await self ._execute_tools (
330349 run_id = run .id ,
331350 tool_call_requests = run .tool_call_requests ,
@@ -368,7 +387,9 @@ async def run(
368387 max_retry_delay (Optional[float], optional): The maximum delay between retries in milliseconds.
369388 Defaults to 60000.
370389 max_retry_count (Optional[float], optional): The maximum number of retry attempts. Defaults to 1.
371- max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10.
390+ max_turns (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10.
391+ max_turns_raises (Optional[bool], optional): Whether to raise an error when the maximum number of turns is
392+ reached. Defaults to True.
372393 validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the output.
373394
374395 Returns:
@@ -385,7 +406,7 @@ async def run(
385406 res ,
386407 prepared_run .schema_id ,
387408 validator ,
388- current_iteration = 0 ,
409+ current_iteration = 1 ,
389410 # TODO[test]: add test with custom validator
390411 ** new_kwargs ,
391412 )
@@ -424,7 +445,7 @@ async def stream(
424445 max_retry_delay (Optional[float], optional): The maximum delay between retries in milliseconds.
425446 Defaults to 60000.
426447 max_retry_count (Optional[float], optional): The maximum number of retry attempts. Defaults to 1.
427- max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10.
448+ max_turns (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10.
428449 validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the output.
429450
430451 Returns:
0 commit comments