Skip to content

Commit 57baccc

Browse files
authored
Merge pull request #34 from WorkflowAI/guillaume/tools-feedback
Tools feedback
2 parents 9d106a4 + 31831d2 commit 57baccc

File tree

8 files changed

+112
-17
lines changed

8 files changed

+112
-17
lines changed

.vscode/launch.json

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"version": "0.2.0",
3+
"configurations": [
4+
{
5+
"name": "Python: Debug Tests",
6+
"type": "debugpy",
7+
"request": "launch",
8+
"program": "${file}",
9+
"purpose": ["debug-test"],
10+
"console": "integratedTerminal",
11+
"justMyCode": false
12+
}
13+
]
14+
}

examples/city_to_capital_task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class CityToCapitalTaskOutput(BaseModel):
2121
)
2222

2323

24-
@workflowai.task(schema_id=1)
24+
@workflowai.agent(schema_id=1)
2525
async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTaskOutput: ...
2626

2727

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "workflowai"
3-
version = "0.6.0.dev3"
3+
version = "0.6.0.dev4"
44
description = ""
55
authors = ["Guillaume Aquilina <guillaume@workflowai.com>"]
66
readme = "README.md"

workflowai/core/client/agent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333

3434
class Agent(Generic[AgentInput, AgentOutput]):
35+
_DEFAULT_MAX_ITERATIONS = 10
36+
3537
def __init__(
3638
self,
3739
agent_id: str,
@@ -216,6 +218,8 @@ async def _build_run(
216218
run._agent = self # pyright: ignore [reportPrivateUsage]
217219

218220
if run.tool_call_requests:
221+
if current_iteration >= kwargs.get("max_iterations", self._DEFAULT_MAX_ITERATIONS):
222+
raise WorkflowAIError(error=BaseError(message="max tool iterations reached"), response=None)
219223
with_reply = await self._execute_tools(
220224
run_id=run.id,
221225
tool_call_requests=run.tool_call_requests,
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import Any
2+
3+
from pydantic.json_schema import GenerateJsonSchema
4+
from typing_extensions import override
5+
6+
7+
class JsonSchemaGenerator(GenerateJsonSchema):
8+
"""A schema generator that simplifies the schemas generated by pydantic."""
9+
10+
@override
11+
def generate(self, *args: Any, **kwargs: Any):
12+
generated = super().generate(*args, **kwargs)
13+
# Remove the title from the schema
14+
generated.pop("title", None)
15+
return generated
16+
17+
@override
18+
def field_title_should_be_set(self, *args: Any, **kwargs: Any) -> bool:
19+
return False
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from pydantic import BaseModel
2+
3+
from workflowai.core.utils._schema_generator import JsonSchemaGenerator
4+
5+
6+
class TestJsonSchemaGenerator:
7+
def test_generate(self):
8+
class TestModel(BaseModel):
9+
name: str
10+
11+
schema = TestModel.model_json_schema(schema_generator=JsonSchemaGenerator)
12+
assert schema == {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}

workflowai/core/utils/_tools.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from pydantic import BaseModel
66

7+
from workflowai.core.utils._schema_generator import JsonSchemaGenerator
8+
79
ToolFunction = Callable[..., Any]
810

911

@@ -57,26 +59,33 @@ def _get_type_schema(param_type: type) -> dict[str, Any]:
5759
if param_type is bool:
5860
return {"type": "boolean"}
5961

60-
if isinstance(param_type, BaseModel):
61-
return param_type.model_json_schema()
62+
if issubclass(param_type, BaseModel):
63+
return param_type.model_json_schema(by_alias=True, schema_generator=JsonSchemaGenerator)
6264

6365
raise ValueError(f"Unsupported type: {param_type}")
6466

6567

68+
def _schema_from_type_hint(param_type_hint: Any) -> dict[str, Any]:
69+
param_type = param_type_hint.__origin__ if hasattr(param_type_hint, "__origin__") else param_type_hint
70+
if not isinstance(param_type, type):
71+
raise ValueError(f"Unsupported type: {param_type}")
72+
73+
param_description = param_type_hint.__metadata__[0] if hasattr(param_type_hint, "__metadata__") else None
74+
param_schema = _get_type_schema(param_type)
75+
if param_description:
76+
param_schema["description"] = param_description
77+
78+
return param_schema
79+
80+
6681
def _build_input_schema(sig: inspect.Signature, type_hints: dict[str, Any]) -> dict[str, Any]:
6782
input_schema: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
6883

6984
for param_name, param in sig.parameters.items():
7085
if param_name == "self":
7186
continue
7287

73-
param_type_hint = type_hints[param_name]
74-
param_type = param_type_hint.__origin__ if hasattr(param_type_hint, "__origin__") else param_type_hint
75-
param_description = param_type_hint.__metadata__[0] if hasattr(param_type_hint, "__metadata__") else None
76-
77-
param_schema = _get_type_schema(param_type) if isinstance(param_type, type) else {"type": "string"}
78-
if param_description is not None:
79-
param_schema["description"] = param_description
88+
param_schema = _schema_from_type_hint(type_hints[param_name])
8089

8190
if param.default is inspect.Parameter.empty:
8291
input_schema["required"].append(param_name)
@@ -91,9 +100,4 @@ def _build_output_schema(type_hints: dict[str, Any]) -> dict[str, Any]:
91100
if not return_type:
92101
raise ValueError("Return type annotation is required")
93102

94-
return_type_base = return_type.__origin__ if hasattr(return_type, "__origin__") else return_type
95-
96-
if not isinstance(return_type_base, type):
97-
raise ValueError(f"Unsupported return type: {return_type_base}")
98-
99-
return _get_type_schema(return_type_base)
103+
return _schema_from_type_hint(return_type)

workflowai/core/utils/_tools_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from enum import Enum
22
from typing import Annotated
33

4+
from pydantic import BaseModel
5+
46
from workflowai.core.utils._tools import tool_schema
57

68

@@ -70,3 +72,43 @@ def sample_method(self, value: int) -> str:
7072
assert schema.output_schema == {
7173
"type": "string",
7274
}
75+
76+
def test_with_base_model_in_input(self):
77+
class TestModel(BaseModel):
78+
name: str
79+
80+
def sample_func(model: TestModel) -> str: ...
81+
82+
schema = tool_schema(sample_func)
83+
84+
assert schema.input_schema == {
85+
"type": "object",
86+
"properties": {
87+
"model": {
88+
"properties": {
89+
"name": {
90+
"type": "string",
91+
},
92+
},
93+
"required": [
94+
"name",
95+
],
96+
"type": "object",
97+
},
98+
},
99+
"required": ["model"],
100+
}
101+
102+
def test_with_base_model_in_output(self):
103+
class TestModel(BaseModel):
104+
val: int
105+
106+
def sample_func() -> TestModel: ...
107+
108+
schema = tool_schema(sample_func)
109+
110+
assert schema.output_schema == {
111+
"type": "object",
112+
"properties": {"val": {"type": "integer"}},
113+
"required": ["val"],
114+
}

0 commit comments

Comments
 (0)