Skip to content

Commit 37a21b2

Browse files
committed
fix: handle more types in tool defs
1 parent 58b3abc commit 37a21b2

File tree

2 files changed

+91
-12
lines changed

2 files changed

+91
-12
lines changed

workflowai/core/utils/_tools.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import contextlib
2+
import datetime
13
import inspect
24
from enum import Enum
35
from typing import Any, Callable, NamedTuple, Optional, cast, get_type_hints
46

5-
from pydantic import BaseModel
7+
from pydantic import BaseModel, TypeAdapter
68

79
from workflowai.core.utils._schema_generator import JsonSchemaGenerator
810

@@ -15,6 +17,14 @@ class SchemaDeserializer(NamedTuple):
1517
deserializer: Optional[Callable[[Any], Any]] = None
1618

1719

20+
def _serialize_datetime(x: datetime.datetime) -> str:
21+
return x.isoformat()
22+
23+
24+
def _deserialize_datetime(x: str) -> datetime.datetime:
25+
return datetime.datetime.fromisoformat(x)
26+
27+
1828
def _get_type_schema(param_type: type):
1929
"""Convert a Python type to its corresponding JSON schema type.
2030
@@ -24,10 +34,6 @@ def _get_type_schema(param_type: type):
2434
Returns:
2535
A dictionary containing the JSON schema type definition
2636
"""
27-
if issubclass(param_type, Enum):
28-
if not issubclass(param_type, str):
29-
raise ValueError(f"Non string enums are not supported: {param_type}")
30-
return SchemaDeserializer({"type": "string", "enum": [e.value for e in param_type]})
3137

3238
if param_type is str:
3339
return SchemaDeserializer({"type": "string"})
@@ -41,11 +47,33 @@ def _get_type_schema(param_type: type):
4147
if param_type is bool:
4248
return SchemaDeserializer({"type": "boolean"})
4349

44-
if issubclass(param_type, BaseModel):
50+
if param_type is datetime.datetime:
51+
return SchemaDeserializer(
52+
{"type": "string", "format": "date-time"},
53+
serializer=_serialize_datetime,
54+
deserializer=_deserialize_datetime,
55+
)
56+
57+
if inspect.isclass(param_type):
58+
if issubclass(param_type, BaseModel):
59+
return SchemaDeserializer(
60+
schema=param_type.model_json_schema(by_alias=True, schema_generator=JsonSchemaGenerator),
61+
serializer=lambda x: cast(BaseModel, x).model_dump(mode="json"), # pyright: ignore [reportUnknownLambdaType]
62+
deserializer=param_type.model_validate,
63+
)
64+
65+
if issubclass(param_type, Enum):
66+
if not issubclass(param_type, str):
67+
raise ValueError(f"Non string enums are not supported: {param_type}")
68+
return SchemaDeserializer({"type": "string", "enum": [e.value for e in param_type]})
69+
70+
# Attempting to build a type adapter with pydantic
71+
with contextlib.suppress(Exception):
72+
adapter = TypeAdapter[Any](param_type)
4573
return SchemaDeserializer(
46-
schema=param_type.model_json_schema(by_alias=True, schema_generator=JsonSchemaGenerator),
47-
serializer=lambda x: cast(BaseModel, x).model_dump(mode="json"), # pyright: ignore [reportUnknownLambdaType]
48-
deserializer=param_type.model_validate,
74+
schema=adapter.json_schema(),
75+
deserializer=adapter.validate_python, # pyright: ignore [reportUnknownLambdaType]
76+
serializer=lambda x: adapter.dump_python(x, mode="json"), # pyright: ignore [reportUnknownLambdaType]
4977
)
5078

5179
raise ValueError(f"Unsupported type: {param_type}")

workflowai/core/utils/_tools_test.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,42 @@
1+
import json
2+
from datetime import datetime
13
from enum import Enum
2-
from typing import Annotated
4+
from typing import Annotated, Any
35

6+
import pytest
47
from pydantic import BaseModel
5-
6-
from workflowai.core.utils._tools import tool_schema
8+
from zoneinfo import ZoneInfo
9+
10+
from workflowai.core.utils._tools import _get_type_schema, tool_schema # pyright: ignore [reportPrivateUsage]
11+
12+
13+
class TestGetTypeSchema:
14+
@pytest.mark.parametrize(
15+
("param_type", "value"),
16+
[
17+
(int, 1),
18+
(float, 1.0),
19+
(bool, True),
20+
(str, "test"),
21+
(datetime, datetime.now(tz=ZoneInfo("UTC"))),
22+
(ZoneInfo, ZoneInfo("UTC")),
23+
(list[int], [1, 2, 3]),
24+
(dict[str, int], {"a": 1, "b": 2}),
25+
],
26+
)
27+
def test_get_type_schema(self, param_type: Any, value: Any):
28+
schema = _get_type_schema(param_type)
29+
if schema.serializer is None or schema.deserializer is None:
30+
assert schema.serializer is None
31+
assert schema.deserializer is None
32+
33+
# Check that the value is serializable and deserializable with plain json
34+
assert json.loads(json.dumps(value)) == value
35+
return
36+
37+
serialized = schema.serializer(value)
38+
deserialized = schema.deserializer(serialized)
39+
assert deserialized == value
740

841

942
class TestToolSchema:
@@ -123,3 +156,21 @@ def sample_func() -> TestModel: ...
123156
}
124157
assert output_schema.serializer is not None
125158
assert output_schema.serializer(TestModel(val=10)) == {"val": 10}
159+
160+
def test_with_datetime_in_input(self):
161+
def sample_func(time: datetime) -> str: ...
162+
163+
input_schema, _ = tool_schema(sample_func)
164+
165+
assert input_schema.deserializer is not None
166+
assert input_schema.deserializer({"time": "2024-01-01T12:00:00+00:00"}) == {
167+
"time": datetime(
168+
2024,
169+
1,
170+
1,
171+
12,
172+
0,
173+
0,
174+
tzinfo=ZoneInfo("UTC"),
175+
),
176+
}

0 commit comments

Comments
 (0)