1+ import json
2+ from datetime import datetime
13from enum import Enum
2- from typing import Annotated
4+ from typing import Annotated , Any
35
6+ import pytest
47from pydantic import BaseModel
5-
6- from workflowai .core .utils ._tools import tool_schema
8+ from zoneinfo import ZoneInfo
9+
10+ from workflowai .core .utils ._tools import _get_type_schema , tool_schema # pyright: ignore [reportPrivateUsage]
11+
12+
13+ class TestGetTypeSchema :
14+ class _BasicEnum (str , Enum ):
15+ A = "a"
16+ B = "b"
17+
18+ class _BasicModel (BaseModel ):
19+ a : int
20+ b : str
21+
22+ @pytest .mark .parametrize (
23+ ("param_type" , "value" ),
24+ [
25+ (int , 1 ),
26+ (float , 1.0 ),
27+ (bool , True ),
28+ (str , "test" ),
29+ (datetime , datetime .now (tz = ZoneInfo ("UTC" ))),
30+ (ZoneInfo , ZoneInfo ("UTC" )),
31+ (list [int ], [1 , 2 , 3 ]),
32+ (dict [str , int ], {"a" : 1 , "b" : 2 }),
33+ (_BasicEnum , _BasicEnum .A ),
34+ (_BasicModel , _BasicModel (a = 1 , b = "test" )),
35+ (list [_BasicModel ], [_BasicModel (a = 1 , b = "test" ), _BasicModel (a = 2 , b = "test2" )]),
36+ (tuple [int , str ], (1 , "test" )),
37+ ],
38+ )
39+ def test_get_type_schema (self , param_type : Any , value : Any ):
40+ schema = _get_type_schema (param_type )
41+ if schema .serializer is None or schema .deserializer is None :
42+ assert schema .serializer is None
43+ assert schema .deserializer is None
44+
45+ # Check that the value is serializable and deserializable with plain json
46+ assert json .loads (json .dumps (value )) == value
47+ return
48+
49+ serialized = schema .serializer (value )
50+ deserialized = schema .deserializer (serialized )
51+ assert deserialized == value
752
853
954class TestToolSchema :
@@ -17,6 +62,7 @@ def sample_func(
1762 age : int ,
1863 height : float ,
1964 is_active : bool ,
65+ date : datetime ,
2066 mode : TestMode = TestMode .FAST ,
2167 ) -> bool :
2268 """Sample function for testing"""
@@ -44,8 +90,12 @@ def sample_func(
4490 "type" : "string" ,
4591 "enum" : ["fast" , "slow" ],
4692 },
93+ "date" : {
94+ "type" : "string" ,
95+ "format" : "date-time" ,
96+ },
4797 },
48- "required" : ["name" , "age" , "height" , "is_active" ], # 'mode' is not required
98+ "required" : ["name" , "age" , "height" , "is_active" , "date" ], # 'mode' is not required
4999 }
50100 assert output_schema .schema == {
51101 "type" : "boolean" ,
@@ -123,3 +173,21 @@ def sample_func() -> TestModel: ...
123173 }
124174 assert output_schema .serializer is not None
125175 assert output_schema .serializer (TestModel (val = 10 )) == {"val" : 10 }
176+
177+ def test_with_datetime_in_input (self ):
178+ def sample_func (time : datetime ) -> str : ...
179+
180+ input_schema , _ = tool_schema (sample_func )
181+
182+ assert input_schema .deserializer is not None
183+ assert input_schema .deserializer ({"time" : "2024-01-01T12:00:00+00:00" }) == {
184+ "time" : datetime (
185+ 2024 ,
186+ 1 ,
187+ 1 ,
188+ 12 ,
189+ 0 ,
190+ 0 ,
191+ tzinfo = ZoneInfo ("UTC" ),
192+ ),
193+ }
0 commit comments