44
55from pydantic import BaseModel
66
7+ from workflowai .core .utils ._schema_generator import JsonSchemaGenerator
8+
79ToolFunction = Callable [..., Any ]
810
911
@@ -57,26 +59,33 @@ def _get_type_schema(param_type: type) -> dict[str, Any]:
5759 if param_type is bool :
5860 return {"type" : "boolean" }
5961
60- if isinstance (param_type , BaseModel ):
61- return param_type .model_json_schema ()
62+ if issubclass (param_type , BaseModel ):
63+ return param_type .model_json_schema (by_alias = True , schema_generator = JsonSchemaGenerator )
6264
6365 raise ValueError (f"Unsupported type: { param_type } " )
6466
6567
68+ def _schema_from_type_hint (param_type_hint : Any ) -> dict [str , Any ]:
69+ param_type = param_type_hint .__origin__ if hasattr (param_type_hint , "__origin__" ) else param_type_hint
70+ if not isinstance (param_type , type ):
71+ raise ValueError (f"Unsupported type: { param_type } " )
72+
73+ param_description = param_type_hint .__metadata__ [0 ] if hasattr (param_type_hint , "__metadata__" ) else None
74+ param_schema = _get_type_schema (param_type )
75+ if param_description :
76+ param_schema ["description" ] = param_description
77+
78+ return param_schema
79+
80+
6681def _build_input_schema (sig : inspect .Signature , type_hints : dict [str , Any ]) -> dict [str , Any ]:
6782 input_schema : dict [str , Any ] = {"type" : "object" , "properties" : {}, "required" : []}
6883
6984 for param_name , param in sig .parameters .items ():
7085 if param_name == "self" :
7186 continue
7287
73- param_type_hint = type_hints [param_name ]
74- param_type = param_type_hint .__origin__ if hasattr (param_type_hint , "__origin__" ) else param_type_hint
75- param_description = param_type_hint .__metadata__ [0 ] if hasattr (param_type_hint , "__metadata__" ) else None
76-
77- param_schema = _get_type_schema (param_type ) if isinstance (param_type , type ) else {"type" : "string" }
78- if param_description is not None :
79- param_schema ["description" ] = param_description
88+ param_schema = _schema_from_type_hint (type_hints [param_name ])
8089
8190 if param .default is inspect .Parameter .empty :
8291 input_schema ["required" ].append (param_name )
@@ -91,9 +100,4 @@ def _build_output_schema(type_hints: dict[str, Any]) -> dict[str, Any]:
91100 if not return_type :
92101 raise ValueError ("Return type annotation is required" )
93102
94- return_type_base = return_type .__origin__ if hasattr (return_type , "__origin__" ) else return_type
95-
96- if not isinstance (return_type_base , type ):
97- raise ValueError (f"Unsupported return type: { return_type_base } " )
98-
99- return _get_type_schema (return_type_base )
103+ return _schema_from_type_hint (return_type )
0 commit comments