Skip to content

Commit 4c3448f

Browse files
authored
Merge pull request #72 from WorkflowAI/guillaume/fix-warning
Remove warnings from optional models
2 parents fd8a163 + ae21822 commit 4c3448f

File tree

4 files changed

+101
-11
lines changed

4 files changed

+101
-11
lines changed

workflowai/core/client/agent.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ async def reply(
472472
user_message: Optional[str] = None,
473473
tool_results: Optional[Iterable[ToolCallResult]] = None,
474474
current_iteration: int = 0,
475+
max_retries: int = 2,
475476
**kwargs: Unpack[RunParams[AgentOutput]],
476477
):
477478
"""Reply to a run to provide additional information or context.
@@ -489,7 +490,18 @@ async def reply(
489490
prepared_run = await self._prepare_reply(run_id, user_message, tool_results, stream=False, **kwargs)
490491
validator, new_kwargs = self._sanitize_validator(kwargs, self._default_validator)
491492

492-
res = await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True)
493+
async def _with_retries():
494+
err: Optional[WorkflowAIError] = None
495+
for _ in range(max_retries):
496+
try:
497+
return await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True)
498+
except WorkflowAIError as e: # noqa: PERF203
499+
if e.code != "object_not_found":
500+
raise e
501+
err = e
502+
raise err or RuntimeError("This should never raise")
503+
504+
res = await _with_retries()
493505
return await self._build_run(
494506
res,
495507
prepared_run.schema_id,

workflowai/core/client/agent_test.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,3 +1078,75 @@ async def test_stream_validation_final_error(
10781078

10791079
assert e.value.partial_output == {"message": 1}
10801080
assert e.value.run_id == "1"
1081+
1082+
1083+
class TestReply:
1084+
async def test_reply_success(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]):
1085+
httpx_mock.add_response(
1086+
url="http://localhost:8000/v1/_/agents/123/runs/1/reply",
1087+
json=fixtures_json("task_run.json"),
1088+
)
1089+
reply = await agent.reply(run_id="1", user_message="test message")
1090+
assert reply.output.message == "Austin"
1091+
1092+
assert len(httpx_mock.get_requests()) == 1
1093+
1094+
async def test_reply_first_404(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]):
1095+
"""Check that we retry once if the run is not found"""
1096+
1097+
httpx_mock.add_response(
1098+
url="http://localhost:8000/v1/_/agents/123/runs/1/reply",
1099+
status_code=404,
1100+
json={
1101+
"error": {
1102+
"code": "object_not_found",
1103+
},
1104+
},
1105+
)
1106+
1107+
httpx_mock.add_response(
1108+
url="http://localhost:8000/v1/_/agents/123/runs/1/reply",
1109+
json=fixtures_json("task_run.json"),
1110+
)
1111+
1112+
reply = await agent.reply(run_id="1", user_message="test message")
1113+
assert reply.output.message == "Austin"
1114+
1115+
assert len(httpx_mock.get_requests()) == 2
1116+
1117+
async def test_reply_not_not_found_error(
1118+
self,
1119+
httpx_mock: HTTPXMock,
1120+
agent: Agent[HelloTaskInput, HelloTaskOutput],
1121+
):
1122+
"""Check that we raise the error if it's not a 404"""
1123+
httpx_mock.add_response(
1124+
url="http://localhost:8000/v1/_/agents/123/runs/1/reply",
1125+
status_code=400,
1126+
json={
1127+
"error": {
1128+
"code": "whatever",
1129+
},
1130+
},
1131+
)
1132+
with pytest.raises(WorkflowAIError) as e:
1133+
await agent.reply(run_id="1", user_message="test message")
1134+
assert e.value.code == "whatever"
1135+
assert len(httpx_mock.get_requests()) == 1
1136+
1137+
async def test_reply_multiple_retries(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]):
1138+
"""Check that we retry once if the run is not found"""
1139+
httpx_mock.add_response(
1140+
url="http://localhost:8000/v1/_/agents/123/runs/1/reply",
1141+
status_code=404,
1142+
json={
1143+
"error": {
1144+
"code": "object_not_found",
1145+
},
1146+
},
1147+
is_reusable=True,
1148+
)
1149+
with pytest.raises(WorkflowAIError) as e:
1150+
await agent.reply(run_id="1", user_message="test message")
1151+
assert e.value.code == "object_not_found"
1152+
assert len(httpx_mock.get_requests()) == 2

workflowai/core/utils/_pydantic.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Mapping, Sequence
23
from typing import Any, TypeVar, get_args, get_origin
34

@@ -25,9 +26,11 @@ def _copy_field_info(field_info: FieldInfo, **overrides: Any):
2526
certain values.
2627
"""
2728

29+
_excluded = {"annotation", "required"}
30+
2831
kwargs = overrides
2932
for k, v in field_info.__repr_args__():
30-
if k in kwargs or not k:
33+
if k in kwargs or not k or k in _excluded:
3134
continue
3235
kwargs[k] = v
3336

@@ -79,7 +82,6 @@ def partial_model(base: type[BM]) -> type[BM]:
7982
overrides: dict[str, Any] = {}
8083
try:
8184
annotation = _optional_annotation(field.annotation)
82-
overrides["annotation"] = annotation
8385
overrides["default"] = _default_value_from_annotation(annotation)
8486
except Exception: # noqa: BLE001
8587
logger.debug("Failed to make annotation optional", exc_info=True)
@@ -95,10 +97,12 @@ def custom_eq(o1: BM, o2: Any):
9597
return False
9698
return o1.model_dump() == o2.model_dump()
9799

98-
return create_model( # pyright: ignore [reportCallIssue, reportUnknownVariableType]
99-
f"Partial{base.__name__}",
100-
__base__=base,
101-
__eq__=custom_eq,
102-
__hash__=base.__hash__,
103-
**default_fields, # pyright: ignore [reportArgumentType]
104-
)
100+
with warnings.catch_warnings():
101+
warnings.filterwarnings("ignore", category=RuntimeWarning, message="fields may not start with an underscore")
102+
return create_model( # pyright: ignore [reportCallIssue, reportUnknownVariableType]
103+
f"Partial{base.__name__}",
104+
__base__=base,
105+
__eq__=custom_eq,
106+
__hash__=base.__hash__,
107+
**default_fields, # pyright: ignore [reportArgumentType]
108+
)

workflowai/core/utils/_pydantic_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
class TestPartialModel:
10-
def test_partial_model_equals(self):
10+
def test_partial_model_equals(self, recwarn: pytest.WarningsRecorder):
1111
class SimpleModel(BaseModel):
1212
name: str
1313

@@ -16,6 +16,8 @@ class SimpleModel(BaseModel):
1616

1717
assert SimpleModel(name="John") == partial.model_validate({"name": "John"})
1818

19+
assert len(recwarn.list) == 0
20+
1921
def test_simple_model(self):
2022
class SimpleModel(BaseModel):
2123
name1: str

0 commit comments

Comments
 (0)