Skip to content

Commit fb5e7e4

Browse files
authored
Merge pull request #62 from WorkflowAI/feature/run-completions
feat: Add fetch_completions method to Run class
2 parents 9435e5e + 37c97aa commit fb5e7e4

File tree

8 files changed

+222
-19
lines changed

8 files changed

+222
-19
lines changed

tests/fixtures/completions.json

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
{
2+
"completions": [
3+
{
4+
"messages": [
5+
{
6+
"role": "system",
7+
"content": "I am instructions"
8+
},
9+
{
10+
"role": "user",
11+
"content": "I am user message"
12+
}
13+
],
14+
"response": "This is a test response",
15+
"usage": {
16+
"completion_token_count": 222,
17+
"completion_cost_usd": 0.00013319999999999999,
18+
"prompt_token_count": 1230,
19+
"prompt_cost_usd": 0.00018449999999999999,
20+
"model_context_window_size": 1048576
21+
}
22+
}
23+
]
24+
}

tests/fixtures/task_example.json

Lines changed: 0 additions & 16 deletions
This file was deleted.

workflowai/core/client/_models.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from workflowai.core._common_types import OutputValidator
77
from workflowai.core.domain.cache_usage import CacheUsage
8+
from workflowai.core.domain.completion import Completion
89
from workflowai.core.domain.run import Run
910
from workflowai.core.domain.task import AgentOutput
1011
from workflowai.core.domain.tool_call import ToolCall as DToolCall
@@ -160,6 +161,7 @@ class CreateAgentResponse(BaseModel):
160161

161162
class ModelMetadata(BaseModel):
162163
"""Metadata for a model."""
164+
163165
provider_name: str = Field(description="Name of the model provider")
164166
price_per_input_token_usd: Optional[float] = Field(None, description="Cost per input token in USD")
165167
price_per_output_token_usd: Optional[float] = Field(None, description="Cost per output token in USD")
@@ -170,6 +172,7 @@ class ModelMetadata(BaseModel):
170172

171173
class ModelInfo(BaseModel):
172174
"""Information about a model."""
175+
173176
id: str = Field(description="Unique identifier for the model")
174177
name: str = Field(description="Display name of the model")
175178
icon_url: Optional[str] = Field(None, description="URL for the model's icon")
@@ -187,11 +190,19 @@ class ModelInfo(BaseModel):
187190

188191
T = TypeVar("T")
189192

193+
190194
class Page(BaseModel, Generic[T]):
191195
"""A generic paginated response."""
196+
192197
items: list[T] = Field(description="List of items in this page")
193198
count: Optional[int] = Field(None, description="Total number of items available")
194199

195200

196201
class ListModelsResponse(Page[ModelInfo]):
197202
"""Response from the list models API endpoint."""
203+
204+
205+
class CompletionsResponse(BaseModel):
206+
"""Response from the completions API endpoint."""
207+
208+
completions: list[Completion]

workflowai/core/client/agent.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from workflowai.core._common_types import BaseRunParams, OutputValidator, VersionRunParams
1010
from workflowai.core.client._api import APIClient
1111
from workflowai.core.client._models import (
12+
CompletionsResponse,
1213
CreateAgentRequest,
1314
CreateAgentResponse,
1415
ListModelsResponse,
@@ -24,6 +25,7 @@
2425
intolerant_validator,
2526
tolerant_validator,
2627
)
28+
from workflowai.core.domain.completion import Completion
2729
from workflowai.core.domain.errors import BaseError, WorkflowAIError
2830
from workflowai.core.domain.run import Run
2931
from workflowai.core.domain.task import AgentInput, AgentOutput
@@ -493,3 +495,18 @@ async def list_models(self) -> list[ModelInfo]:
493495
returns=ListModelsResponse,
494496
)
495497
return response.items
498+
499+
async def fetch_completions(self, run_id: str) -> list[Completion]:
500+
"""Fetch the completions for a run.
501+
502+
Args:
503+
run_id (str): The id of the run to fetch completions for.
504+
505+
Returns:
506+
list[Completion]: The completions for the run.
507+
"""
508+
raw = await self.api.get(
509+
f"/v1/_/agents/{self.agent_id}/runs/{run_id}/completions",
510+
returns=CompletionsResponse,
511+
)
512+
return raw.completions

workflowai/core/client/agent_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from workflowai.core.client.client import (
2020
WorkflowAI,
2121
)
22+
from workflowai.core.domain.completion import Completion, CompletionUsage, Message
2223
from workflowai.core.domain.errors import WorkflowAIError
2324
from workflowai.core.domain.run import Run
2425
from workflowai.core.domain.version_properties import VersionProperties
@@ -539,3 +540,31 @@ async def test_list_models_registers_if_needed(
539540
assert models[0].modes == ["chat"]
540541
assert models[0].metadata is not None
541542
assert models[0].metadata.provider_name == "OpenAI"
543+
544+
545+
class TestFetchCompletions:
546+
async def test_fetch_completions(self, agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock):
547+
"""Test that fetch_completions correctly fetches and returns completions."""
548+
# Mock the HTTP response instead of the API client method
549+
httpx_mock.add_response(
550+
url="http://localhost:8000/v1/_/agents/123/runs/1/completions",
551+
json=fixtures_json("completions.json"),
552+
)
553+
554+
completions = await agent.fetch_completions("1")
555+
assert completions == [
556+
Completion(
557+
messages=[
558+
Message(role="system", content="I am instructions"),
559+
Message(role="user", content="I am user message"),
560+
],
561+
response="This is a test response",
562+
usage=CompletionUsage(
563+
completion_token_count=222,
564+
completion_cost_usd=0.00013319999999999999,
565+
prompt_token_count=1230,
566+
prompt_cost_usd=0.00018449999999999999,
567+
model_context_window_size=1048576,
568+
),
569+
),
570+
]
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel, Field
4+
5+
6+
class CompletionUsage(BaseModel):
7+
"""Usage information for a completion."""
8+
9+
completion_token_count: Optional[int] = None
10+
completion_cost_usd: Optional[float] = None
11+
reasoning_token_count: Optional[int] = None
12+
prompt_token_count: Optional[int] = None
13+
prompt_token_count_cached: Optional[int] = None
14+
prompt_cost_usd: Optional[float] = None
15+
prompt_audio_token_count: Optional[int] = None
16+
prompt_audio_duration_seconds: Optional[float] = None
17+
prompt_image_count: Optional[int] = None
18+
model_context_window_size: Optional[int] = None
19+
20+
21+
class Message(BaseModel):
22+
"""A message in a completion."""
23+
24+
role: str = ""
25+
content: str = ""
26+
27+
28+
class Completion(BaseModel):
29+
"""A completion from the model."""
30+
31+
messages: list[Message] = Field(default_factory=list)
32+
response: Optional[str] = None
33+
usage: CompletionUsage = Field(default_factory=CompletionUsage)
34+
35+
36+
class CompletionsResponse(BaseModel):
37+
"""Response from the completions API endpoint."""
38+
39+
completions: list[Completion]

workflowai/core/domain/run.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from workflowai import env
88
from workflowai.core import _common_types
99
from workflowai.core.client import _types
10+
from workflowai.core.domain.completion import Completion
1011
from workflowai.core.domain.errors import BaseError
1112
from workflowai.core.domain.task import AgentOutput
1213
from workflowai.core.domain.tool_call import ToolCall, ToolCallRequest, ToolCallResult
@@ -130,6 +131,23 @@ def __str__(self) -> str:
130131
def run_url(self):
131132
return f"{env.WORKFLOWAI_APP_URL}/_/agents/{self.agent_id}/runs/{self.id}"
132133

134+
async def fetch_completions(self) -> list[Completion]:
135+
"""Fetch the completions for this run.
136+
137+
Returns:
138+
list[Completion]: The list of completions
139+
with their messages, responses and usage information.
140+
141+
Raises:
142+
ValueError: If the agent is not set or if the run id is not set.
143+
"""
144+
if not self._agent:
145+
raise ValueError("Agent is not set")
146+
if not self.id:
147+
raise ValueError("Run id is not set")
148+
149+
return await self._agent.fetch_completions(self.id)
150+
133151

134152
class _AgentBase(Protocol, Generic[AgentOutput]):
135153
async def reply(
@@ -141,3 +159,5 @@ async def reply(
141159
) -> "Run[AgentOutput]":
142160
"""Reply to a run. Either a user_message or tool_results must be provided."""
143161
...
162+
163+
async def fetch_completions(self, run_id: str) -> list[Completion]: ...

workflowai/core/domain/run_test.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
import pytest
44
from pydantic import BaseModel
55

6-
from workflowai.core.domain.run import Run
6+
from workflowai.core.domain.completion import Completion, CompletionUsage, Message
7+
from workflowai.core.domain.run import (
8+
Run,
9+
_AgentBase, # pyright: ignore [reportPrivateUsage]
10+
)
711
from workflowai.core.domain.version import Version
812
from workflowai.core.domain.version_properties import VersionProperties
913

@@ -13,8 +17,14 @@ class _TestOutput(BaseModel):
1317

1418

1519
@pytest.fixture
16-
def run1() -> Run[_TestOutput]:
17-
return Run[_TestOutput](
20+
def mock_agent() -> Mock:
21+
mock = Mock(spec=_AgentBase)
22+
return mock
23+
24+
25+
@pytest.fixture
26+
def run1(mock_agent: Mock) -> Run[_TestOutput]:
27+
run = Run[_TestOutput](
1828
id="run-id",
1929
agent_id="agent-id",
2030
schema_id=1,
@@ -26,6 +36,8 @@ def run1() -> Run[_TestOutput]:
2636
tool_calls=[],
2737
tool_call_requests=[],
2838
)
39+
run._agent = mock_agent # pyright: ignore [reportPrivateUsage]
40+
return run
2941

3042

3143
@pytest.fixture
@@ -128,3 +140,70 @@ class TestRunURL:
128140
@patch("workflowai.env.WORKFLOWAI_APP_URL", "https://workflowai.hello")
129141
def test_run_url(self, run1: Run[_TestOutput]):
130142
assert run1.run_url == "https://workflowai.hello/_/agents/agent-id/runs/run-id"
143+
144+
145+
class TestFetchCompletions:
146+
"""Tests for the fetch_completions method of the Run class."""
147+
148+
# Test that the underlying agent is called with the proper run id
149+
async def test_fetch_completions_success(self, run1: Run[_TestOutput], mock_agent: Mock):
150+
mock_agent.fetch_completions.return_value = [
151+
Completion(
152+
messages=[
153+
Message(role="system", content="You are a helpful assistant"),
154+
Message(role="user", content="Hello"),
155+
Message(role="assistant", content="Hi there!"),
156+
],
157+
response="Hi there!",
158+
usage=CompletionUsage(
159+
completion_token_count=3,
160+
completion_cost_usd=0.001,
161+
reasoning_token_count=10,
162+
prompt_token_count=20,
163+
prompt_token_count_cached=0,
164+
prompt_cost_usd=0.002,
165+
prompt_audio_token_count=0,
166+
prompt_audio_duration_seconds=0,
167+
prompt_image_count=0,
168+
model_context_window_size=32000,
169+
),
170+
),
171+
]
172+
173+
# Call fetch_completions
174+
completions = await run1.fetch_completions()
175+
176+
# Verify the API was called correctly
177+
mock_agent.fetch_completions.assert_called_once_with("run-id")
178+
179+
# Verify the response
180+
assert len(completions) == 1
181+
completion = completions[0]
182+
assert len(completion.messages) == 3
183+
assert completion.messages[0].role == "system"
184+
assert completion.messages[0].content == "You are a helpful assistant"
185+
assert completion.response == "Hi there!"
186+
assert completion.usage.completion_token_count == 3
187+
assert completion.usage.completion_cost_usd == 0.001
188+
189+
# Test that fetch_completions fails appropriately when the agent is not set:
190+
# 1. This is a common error case that occurs when a Run object is created without an agent
191+
# 2. The method should fail fast with a clear error message before attempting any API calls
192+
# 3. This protects users from confusing errors that would occur if we tried to use the API client
193+
async def test_fetch_completions_no_agent(self, run1: Run[_TestOutput]):
194+
run1._agent = None # pyright: ignore [reportPrivateUsage]
195+
with pytest.raises(ValueError, match="Agent is not set"):
196+
await run1.fetch_completions()
197+
198+
# Test that fetch_completions fails appropriately when the run ID is not set:
199+
# 1. The run ID is required to construct the API endpoint URL
200+
# 2. Without it, we can't make a valid API request
201+
# 3. This validates that we fail fast with a clear error message
202+
# 4. This should never happen in practice (as Run objects always have an ID),
203+
# but we test it for completeness and to ensure robust error handling
204+
async def test_fetch_completions_no_id(self, run1: Run[_TestOutput]):
205+
mock_agent = Mock()
206+
run1._agent = mock_agent # pyright: ignore [reportPrivateUsage]
207+
run1.id = "" # Empty ID
208+
with pytest.raises(ValueError, match="Run id is not set"):
209+
await run1.fetch_completions()

0 commit comments

Comments
 (0)