|
13 | 13 | ) |
14 | 14 | from tests.utils import fixtures_json |
15 | 15 | from workflowai.core.client._api import APIClient |
| 16 | +from workflowai.core.client._models import ModelInfo |
16 | 17 | from workflowai.core.client.agent import Agent |
17 | 18 | from workflowai.core.client.client import ( |
18 | 19 | WorkflowAI, |
@@ -367,3 +368,139 @@ def test_version_properties_with_model(self, agent: Agent[HelloTaskInput, HelloT |
367 | 368 | def test_version_with_models_and_version(self, agent: Agent[HelloTaskInput, HelloTaskOutput]): |
368 | 369 | # If version is explicitly provided then it takes priority and we log a warning |
369 | 370 | assert agent._sanitize_version({"version": "staging", "model": "gemini-1.5-pro-latest"}) == "staging" # pyright: ignore [reportPrivateUsage] |
| 371 | + |
| 372 | + |
@pytest.mark.asyncio
async def test_list_models(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock):
    """list_models hits the schema's models endpoint and parses every item into a ModelInfo."""
    # Payloads are mocked at the HTTP layer (not on the API client) so the full
    # request/parse round-trip is exercised.
    gpt4_item = {
        "id": "gpt-4",
        "name": "GPT-4",
        "icon_url": "https://example.com/gpt4.png",
        "modes": ["chat"],
        "is_not_supported_reason": None,
        "average_cost_per_run_usd": 0.01,
        "is_latest": True,
        "metadata": {
            "provider_name": "OpenAI",
            "price_per_input_token_usd": 0.0001,
            "price_per_output_token_usd": 0.0002,
            "release_date": "2024-01-01",
            "context_window_tokens": 128000,
            "quality_index": 0.95,
        },
        "is_default": True,
        "providers": ["openai"],
    }
    claude_item = {
        "id": "claude-3",
        "name": "Claude 3",
        "icon_url": "https://example.com/claude3.png",
        "modes": ["chat"],
        "is_not_supported_reason": None,
        "average_cost_per_run_usd": 0.015,
        "is_latest": True,
        "metadata": {
            "provider_name": "Anthropic",
            "price_per_input_token_usd": 0.00015,
            "price_per_output_token_usd": 0.00025,
            "release_date": "2024-03-01",
            "context_window_tokens": 200000,
            "quality_index": 0.98,
        },
        "is_default": False,
        "providers": ["anthropic"],
    }
    httpx_mock.add_response(
        url="http://localhost:8000/v1/_/agents/123/schemas/1/models",
        json={"items": [gpt4_item, claude_item], "count": 2},
    )

    models = await agent.list_models()

    # Exactly the registered endpoint must have been called.
    request = httpx_mock.get_request()
    assert request is not None, "Expected an HTTP request to be made"
    assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models"

    # Every mocked item comes back as a fully parsed ModelInfo, in order.
    assert len(models) == 2
    expected = [
        ("gpt-4", "GPT-4", "OpenAI"),
        ("claude-3", "Claude 3", "Anthropic"),
    ]
    for model, (model_id, model_name, provider_name) in zip(models, expected):
        assert isinstance(model, ModelInfo)
        assert model.id == model_id
        assert model.name == model_name
        assert model.modes == ["chat"]
        assert model.metadata is not None
        assert model.metadata.provider_name == provider_name
| 447 | + |
| 448 | + |
@pytest.mark.asyncio
async def test_list_models_registers_if_needed(
    agent_no_schema: Agent[HelloTaskInput, HelloTaskOutput],
    httpx_mock: HTTPXMock,
):
    """An unregistered agent is registered first, then its models are listed."""
    # First call: agent registration returns the agent id and schema id.
    httpx_mock.add_response(
        url="http://localhost:8000/v1/_/agents",
        json={"id": "123", "schema_id": 2},
    )

    # Second call: the models endpoint for the freshly assigned schema.
    gpt4_item = {
        "id": "gpt-4",
        "name": "GPT-4",
        "icon_url": "https://example.com/gpt4.png",
        "modes": ["chat"],
        "is_not_supported_reason": None,
        "average_cost_per_run_usd": 0.01,
        "is_latest": True,
        "metadata": {
            "provider_name": "OpenAI",
            "price_per_input_token_usd": 0.0001,
            "price_per_output_token_usd": 0.0002,
            "release_date": "2024-01-01",
            "context_window_tokens": 128000,
            "quality_index": 0.95,
        },
        "is_default": True,
        "providers": ["openai"],
    }
    httpx_mock.add_response(
        url="http://localhost:8000/v1/_/agents/123/schemas/2/models",
        json={"items": [gpt4_item], "count": 1},
    )

    models = await agent_no_schema.list_models()

    # Registration must precede the models fetch.
    reqs = httpx_mock.get_requests()
    assert len(reqs) == 2
    assert reqs[0].url == "http://localhost:8000/v1/_/agents"
    assert reqs[1].url == "http://localhost:8000/v1/_/agents/123/schemas/2/models"

    # The single mocked item is returned as a fully parsed ModelInfo.
    assert len(models) == 1
    (model,) = models
    assert isinstance(model, ModelInfo)
    assert model.id == "gpt-4"
    assert model.name == "GPT-4"
    assert model.modes == ["chat"]
    assert model.metadata is not None
    assert model.metadata.provider_name == "OpenAI"
0 commit comments