diff --git a/notebooks/example.ipynb b/notebooks/example.ipynb index 2d82478..a2cbd6b 100644 --- a/notebooks/example.ipynb +++ b/notebooks/example.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "b5fe35fe", "metadata": {}, "outputs": [], @@ -24,9 +24,9 @@ "\n", "from wags_llm.cache import InMemoryCache\n", "from wags_llm.client.bedrock import BedrockClaudeJsonClient\n", - "from wags_llm.prompts.base import BasePromptTemplate\n", - "from wags_llm.prompts.registry import PromptRegistry\n", + "from wags_llm.registry.base import Registry\n", "from wags_llm.services.structured_task import StructuredTaskRunner\n", + "from wags_llm.templates.prompt_template import PromptTemplate\n", "\n", "logging.basicConfig(\n", " stream=sys.stdout,\n", @@ -46,12 +46,12 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "ba6504ff", "metadata": {}, "outputs": [], "source": [ - "class MyPrompt(BasePromptTemplate):\n", + "class MyPrompt(PromptTemplate):\n", " name = \"mondo_id_classification\"\n", " version = \"v1\"\n", "\n", @@ -112,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "5b14b5ce", "metadata": {}, "outputs": [ @@ -133,12 +133,12 @@ " profile_name=\"dev-account\",\n", ")\n", "\n", - "registry = PromptRegistry()\n", + "registry = Registry()\n", "registry.register(MyPrompt())\n", "\n", "service = StructuredTaskRunner(\n", " client=client,\n", - " prompt_registry=registry,\n", + " registry=registry,\n", " cache=InMemoryCache(),\n", ")" ] @@ -227,7 +227,7 @@ ], "metadata": { "kernelspec": { - "display_name": "wags-llm", + "display_name": "wags-llm (3.11.14)", "language": "python", "name": "python3" }, diff --git a/notebooks/skills.ipynb b/notebooks/skills.ipynb index 3396ee4..028ab54 100644 --- a/notebooks/skills.ipynb +++ b/notebooks/skills.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": 
"e9161ef9", "metadata": {}, "outputs": [], @@ -26,9 +26,9 @@ "from pydantic import BaseModel, ConfigDict\n", "\n", "from wags_llm.client.bedrock import BedrockClaudeJsonClient\n", + "from wags_llm.registry.base import Registry\n", "from wags_llm.services.structured_task import StructuredTaskRunner\n", - "from wags_llm.skills.base import BaseSkillTemplate\n", - "from wags_llm.skills.registry import SkillRegistry\n", + "from wags_llm.templates.skill_template import SkillTemplate\n", "\n", "logging.basicConfig(\n", " stream=sys.stdout,\n", @@ -45,7 +45,7 @@ "metadata": {}, "outputs": [], "source": [ - "class VariantCurationSkill(BaseSkillTemplate):\n", + "class VariantCurationSkill(SkillTemplate):\n", " skill_path = Path(\"skills/variant_curation_0.1.0.md\")\n", "\n", " def build_user_prompt(self, payload: Mapping[str, Any]) -> str:\n", @@ -96,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "c44eeb11", "metadata": {}, "outputs": [ @@ -111,7 +111,7 @@ } ], "source": [ - "registry = SkillRegistry()\n", + "registry = Registry()\n", "registry.register(skill)\n", "\n", "MODEL_ID = \"us.anthropic.claude-sonnet-4-6\"\n", @@ -129,13 +129,13 @@ "\n", "task_runner = StructuredTaskRunner(\n", " client=llm_client,\n", - " skill_registry=registry,\n", + " registry=registry,\n", ")" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "d045cf30", "metadata": {}, "outputs": [ @@ -201,7 +201,7 @@ ], "metadata": { "kernelspec": { - "display_name": "wags-llm (3.13.5)", + "display_name": "wags-llm (3.11.14)", "language": "python", "name": "python3" }, @@ -215,7 +215,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.5" + "version": "3.11.14" } }, "nbformat": 4, diff --git a/src/wags_llm/prompts/__init__.py b/src/wags_llm/prompts/__init__.py deleted file mode 100644 index 33aa4a3..0000000 --- a/src/wags_llm/prompts/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ 
-"""Prompt interfaces and registry. - -Define and manage versioned prompt templates. -""" - -from wags_llm.prompts.base import BasePromptTemplate -from wags_llm.prompts.registry import PromptRegistry, build_empty_registry - -__all__ = [ - "BasePromptTemplate", - "PromptRegistry", - "build_empty_registry", -] diff --git a/src/wags_llm/prompts/registry.py b/src/wags_llm/prompts/registry.py deleted file mode 100644 index a0cefab..0000000 --- a/src/wags_llm/prompts/registry.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Prompt registry. - -Maps (name, version) -> prompt instance. - -Users typically: -* create prompts in their project -* register them here or pass a custom registry -""" - -import logging - -from wags_llm.prompts.base import BasePromptTemplate - -_logger = logging.getLogger(__name__) - - -class PromptRegistry: - """Store and retrieve prompts.""" - - def __init__(self) -> None: - """Initialize an empty prompt registry.""" - self._prompts: dict[tuple[str, str], BasePromptTemplate] = {} - - def register(self, prompt: BasePromptTemplate) -> None: - """Register a prompt. - - :param prompt: Prompt instance to register. - """ - _logger.debug( - "Registering prompt: name='%s', version='%s'", prompt.name, prompt.version - ) - self._prompts[(prompt.name, prompt.version)] = prompt - - def get(self, name: str, version: str) -> BasePromptTemplate: - """Retrieve a prompt. - - :param name: Prompt name. - :param version: Prompt version. - :return: Registered prompt. - :raise KeyError: If prompt is not found. - """ - try: - return self._prompts[(name, version)] - except KeyError as exc: - msg = f"Prompt not found: ({name}, {version})" - _logger.exception(msg) - raise KeyError(msg) from exc - - -def build_empty_registry() -> PromptRegistry: - """Create an empty prompt registry. - - :return: New PromptRegistry instance. 
- """ - return PromptRegistry() diff --git a/src/wags_llm/registry/__init__.py b/src/wags_llm/registry/__init__.py new file mode 100644 index 0000000..47766ee --- /dev/null +++ b/src/wags_llm/registry/__init__.py @@ -0,0 +1,11 @@ +"""Registry module. + +Store and retrieve versioned prompt and skill templates. +""" + +from wags_llm.registry.base import Registry, build_empty_registry + +__all__ = [ + "Registry", + "build_empty_registry", +] diff --git a/src/wags_llm/registry/base.py b/src/wags_llm/registry/base.py new file mode 100644 index 0000000..632e3d2 --- /dev/null +++ b/src/wags_llm/registry/base.py @@ -0,0 +1,97 @@ +"""Registry. + +Maps (name, version, TemplateType) -> template instance. +Template instances can be either prompts or skills. + +Users typically: +* create prompts or skills in their project +* register them here or pass a custom registry +""" + +import logging +from types import MappingProxyType + +from wags_llm.templates.base import TemplateType +from wags_llm.templates.prompt_template import PromptTemplate +from wags_llm.templates.skill_template import SkillTemplate + +_logger = logging.getLogger(__name__) + +_TEMPLATE_CLASS_TO_TYPE = MappingProxyType( + { + SkillTemplate: TemplateType.SKILL, + PromptTemplate: TemplateType.PROMPT, + } +) + + +class Registry: + """Store and retrieve prompt and skill templates.""" + + def __init__(self) -> None: + """Initialize an empty template registry.""" + self._templates: dict[ + tuple[str, str, TemplateType], PromptTemplate | SkillTemplate + ] = {} + + def register(self, template: PromptTemplate | SkillTemplate) -> None: + """Register a template. + + :param template: Template instance to register. + :raise TypeError: If the template type is unsupported. + :raise ValueError: If a template with the same name, version, and template type is already registered. 
+ """ + for cls, mapped_type in _TEMPLATE_CLASS_TO_TYPE.items(): + if isinstance(template, cls): + template_type = mapped_type + break + else: + msg = f"Unsupported template type: {type(template)}" + raise TypeError(msg) + + key = (template.name, template.version, template_type) + + _logger.debug( + "Registering template: name='%s', version='%s', template_type='%s'", + template.name, + template.version, + template_type.value, + ) + + if key in self._templates: + msg = f"Template already registered:({template.name}, {template.version}, {template_type.value})" + _logger.error(msg) + raise ValueError(msg) + + self._templates[key] = template + + def get( + self, + name: str, + version: str, + template_type: TemplateType, + ) -> PromptTemplate | SkillTemplate: + """Retrieve a template by name, version, and template type. + + :param name: Template name. + :param version: Template version. + :param template_type: Template type. + :return: Registered template. + :raise KeyError: If template not found. + """ + key = (name, version, template_type) + + try: + return self._templates[key] + except KeyError as exc: + msg = f"Template not found: ({name}, {version}, {template_type.value})" + _logger.exception(msg) + raise KeyError(msg) from exc + + +def build_empty_registry() -> Registry: + """Create an empty registry. + + :return: New Registry instance. + """ + return Registry() diff --git a/src/wags_llm/services/structured_task.py b/src/wags_llm/services/structured_task.py index 68e1b0b..b0d79bc 100644 --- a/src/wags_llm/services/structured_task.py +++ b/src/wags_llm/services/structured_task.py @@ -1,14 +1,14 @@ -"""Run LLM prompts and return schema-validated structured outputs. +"""Run LLM prompts or skills and return schema-validated structured outputs. Inputs: -- prompt +- prompt or skill name and version - context + payload - response model (Pydantic) Returns validated output. 
Users extend by: -- writing prompts +- writing prompts or defining skills - defining response models (Pydantic) """ @@ -24,18 +24,8 @@ from wags_llm.cache.base import BaseCache from wags_llm.client.base import LLMJsonClient from wags_llm.client.exceptions import LLMClientError -from wags_llm.prompts.registry import ( - PromptRegistry, -) -from wags_llm.prompts.registry import ( - build_empty_registry as build_empty_prompt_registry, -) -from wags_llm.skills.registry import ( - SkillRegistry, -) -from wags_llm.skills.registry import ( - build_empty_registry as build_empty_skill_registry, -) +from wags_llm.registry import Registry, build_empty_registry +from wags_llm.templates import TemplateType _logger = logging.getLogger(__name__) @@ -59,20 +49,17 @@ class StructuredTaskRunner: def __init__( self, client: LLMJsonClient, - prompt_registry: PromptRegistry | None = None, - skill_registry: SkillRegistry | None = None, + registry: Registry | None = None, cache: BaseCache | None = None, ) -> None: """Initialize the structured task runner. - :param client: LLM client used to execute prompts. - :param prompt_registry: Registry used to resolve prompts. - :param skill_registry: Registry used to resolve skills. + :param client: LLM client used to execute prompts or skills. + :param registry: Registry used to resolve prompts or skills. :param cache: Optional cache for storing and retrieving task results. """ self.client = client - self.prompt_registry = prompt_registry or build_empty_prompt_registry() - self.skill_registry = skill_registry or build_empty_skill_registry() + self.registry = registry or build_empty_registry() self.cache = cache def execute_skill( @@ -88,40 +75,17 @@ def execute_skill( :param skill_version: Registered skill version. :param payload: JSON-serializable task data. :param response_model: Pydantic model for validation. - :return: Validated task result. + :return: Validated skill result. :raise RuntimeError: If execution or validation fails. 
""" - skill = self.skill_registry.get(skill_name, skill_version) - - cache_result = self._check_cache( + return self._execute( name=skill_name, version=skill_version, payload=payload, response_model=response_model, + template_type=TemplateType.SKILL, ) - if cache_result.cached is not None: - return cache_result.cached - - try: - invoke_json_response = self.client.invoke_json( - system_prompt=skill.build_system_prompt(), - user_prompt=skill.build_user_prompt(payload=payload), - json_schema=response_model.model_json_schema(), - ) - - result = response_model.model_validate(invoke_json_response.parsed_json) - - if self.cache is not None and cache_result.cache_key is not None: - self.cache.set(cache_result.cache_key, result.model_dump()) - - except (LLMClientError, ValidationError) as exc: - msg = f"Task failed: {exc}" - _logger.exception(msg) - raise RuntimeError(msg) from exc - else: - return result - def execute_prompt( self, prompt_name: str, @@ -129,20 +93,46 @@ def execute_prompt( payload: Mapping[str, Any], response_model: type[BaseModel], ) -> BaseModel: - """Execute a task and return validated output. + """Execute a prompt and return validated output. :param prompt_name: Registered prompt name. :param prompt_version: Registered prompt version. :param payload: JSON-serializable task data. :param response_model: Pydantic model for validation. + :return: Validated prompt result. + :raise RuntimeError: If execution or validation fails. + """ + return self._execute( + name=prompt_name, + version=prompt_version, + payload=payload, + response_model=response_model, + template_type=TemplateType.PROMPT, + ) + + def _execute( + self, + name: str, + version: str, + payload: Mapping[str, Any], + response_model: type[BaseModel], + template_type: TemplateType, + ) -> BaseModel: + """Execute a task and return validated output. + + :param name: Registered task name. + :param version: Registered task version. + :param payload: JSON-serializable task data. 
+ :param response_model: Pydantic model for validation. + :param template_type: Registered template type, either skill or prompt. :return: Validated task result. :raise RuntimeError: If execution or validation fails. """ - prompt = self.prompt_registry.get(prompt_name, prompt_version) + registered_task = self.registry.get(name, version, template_type) cache_result = self._check_cache( - name=prompt_name, - version=prompt_version, + name=name, + version=version, payload=payload, response_model=response_model, ) @@ -151,8 +141,8 @@ def execute_prompt( try: invoke_json_response = self.client.invoke_json( - system_prompt=prompt.build_system_prompt(), - user_prompt=prompt.build_user_prompt(payload=payload), + system_prompt=registered_task.build_system_prompt(), + user_prompt=registered_task.build_user_prompt(payload=payload), json_schema=response_model.model_json_schema(), ) @@ -162,7 +152,7 @@ def execute_prompt( self.cache.set(cache_result.cache_key, result.model_dump()) except (LLMClientError, ValidationError) as exc: - msg = f"Task failed: {exc}" + msg = f"{template_type.value} execution failed for {name} version {version}: {exc}" _logger.exception(msg) raise RuntimeError(msg) from exc else: diff --git a/src/wags_llm/templates/__init__.py b/src/wags_llm/templates/__init__.py new file mode 100644 index 0000000..860ede4 --- /dev/null +++ b/src/wags_llm/templates/__init__.py @@ -0,0 +1,15 @@ +"""Prompt and skill template interfaces. + +Define and manage versioned prompt and skill templates. 
+""" + +from wags_llm.templates.base import BaseTemplate, TemplateType +from wags_llm.templates.prompt_template import PromptTemplate +from wags_llm.templates.skill_template import SkillTemplate + +__all__ = [ + "BaseTemplate", + "PromptTemplate", + "SkillTemplate", + "TemplateType", +] diff --git a/src/wags_llm/prompts/base.py b/src/wags_llm/templates/base.py similarity index 68% rename from src/wags_llm/prompts/base.py rename to src/wags_llm/templates/base.py index 63b8792..fe63b6b 100644 --- a/src/wags_llm/prompts/base.py +++ b/src/wags_llm/templates/base.py @@ -1,18 +1,26 @@ -"""Prompt interface. +"""Base template interface. Users extend this to define new tasks. """ from abc import ABC, abstractmethod from collections.abc import Mapping +from enum import StrEnum from typing import Any -class BasePromptTemplate(ABC): - """Base prompt template. +class TemplateType(StrEnum): + """Enum for template types supported by StructuredTaskRunner.""" - :var name: Prompt name. - :var version: Prompt version. + SKILL = "skill" + PROMPT = "prompt" + + +class BaseTemplate(ABC): + """Base template. + + :var name: Template name. + :var version: Template version. """ name: str diff --git a/src/wags_llm/templates/prompt_template.py b/src/wags_llm/templates/prompt_template.py new file mode 100644 index 0000000..b2fe43d --- /dev/null +++ b/src/wags_llm/templates/prompt_template.py @@ -0,0 +1,15 @@ +"""Prompt interface. + +Users extend this to define new tasks. +""" + +from wags_llm.templates.base import BaseTemplate, TemplateType + + +class PromptTemplate(BaseTemplate): + """Prompt template. + + :var template_type: Identifies this as a prompt template; always set to TemplateType.PROMPT + """ + + template_type = TemplateType.PROMPT diff --git a/src/wags_llm/templates/skill_template.py b/src/wags_llm/templates/skill_template.py new file mode 100644 index 0000000..35475cf --- /dev/null +++ b/src/wags_llm/templates/skill_template.py @@ -0,0 +1,102 @@ +"""Skill interface. 
+ +Users extend this to define new skill inputs. +""" + +import logging +import re +from pathlib import Path + +from wags_llm.templates.base import BaseTemplate, TemplateType + +logger = logging.getLogger(__name__) + + +class SkillTemplateError(Exception): + """Raise custom exceptions for SkillTemplateError.""" + + +class SkillTemplate(BaseTemplate): + """Base skill template. + + :var skill_path: Path to the skill `.md` file. Must follow the format + {skill_name}_{version}.md (e.g. entity_detection_v1.md). + If the filename does not follow this format, a SkillTemplateError + will be raised on initialization. + :var template_type: Identifies this as a skill template; always set to TemplateType.SKILL + """ + + skill_path: Path + template_type = TemplateType.SKILL + _skill_file_pattern = re.compile(r"^(?P<name>.+)_(?P<version>[^_]+)\.md$") + + def __init__(self) -> None: + """Initialize the skill template and validate the skill filename format. + + :raise SkillTemplateError: If skill_path does not follow the required format. + """ + self._name, self._version = self._get_skill_name_and_version() + + @property + def name(self) -> str: + """Derive skill name from the file stem. + + :return: Skill name string. + """ + return self._name + + @property + def version(self) -> str: + """Derive skill version from the file stem. + + :return: Skill version string. + """ + return self._version + + def load_skill(self) -> str: + """Load skill instructions from file. + + :return: Skill instruction string. + :raise SkillTemplateError: If skill_path does not exist, if the file + contains invalid UTF-8, or if the file cannot be read. 
+ """ + logger.debug("Loading skill from path: %s", self.skill_path) + if not self.skill_path.exists(): + msg = f"Skill path not found: {self.skill_path}" + raise SkillTemplateError(msg) + + try: + content = self.skill_path.read_text(encoding="utf-8") + except UnicodeDecodeError as exc: + msg = f"Skill file is not valid UTF-8: {self.skill_path}" + logger.exception(msg) + raise SkillTemplateError(msg) from exc + except OSError as exc: + msg = f"Failed to read skill file: {self.skill_path}" + logger.exception(msg) + raise SkillTemplateError(msg) from exc + + logger.info("Loaded skill from path: %s", self.skill_path) + return content + + def build_system_prompt(self) -> str: + """Build the system prompt by loading instructions from the skill file. + + :return: Skill instruction string. + :raise SkillTemplateError: If skill_path does not exist, if the file + contains invalid UTF-8, or if the file cannot be read. + """ + return self.load_skill() + + def _get_skill_name_and_version(self) -> tuple[str, str]: + """Parse the skill filename to extract name and version. + + :return: Tuple of (name, version) strings. + :raise SkillTemplateError: If filename does not follow the required format. 
+ """ + name = self.skill_path.name + match = self._skill_file_pattern.search(name) + if not match: + msg = f"Skill filename must follow the format '{{skill_name}}_{{version}}.md', got path: '{self.skill_path}'" + raise SkillTemplateError(msg) + return match.group("name"), match.group("version") diff --git a/tests/unit/skills/test_skill_v1.md b/tests/examples/test_example_0.1.0.md similarity index 100% rename from tests/unit/skills/test_skill_v1.md rename to tests/examples/test_example_0.1.0.md diff --git a/tests/integration/services/test_json_task.py b/tests/integration/services/test_json_task.py index 4d833bc..65ccedb 100644 --- a/tests/integration/services/test_json_task.py +++ b/tests/integration/services/test_json_task.py @@ -7,12 +7,12 @@ from wags_llm.cache.in_memory import InMemoryCache from wags_llm.client.base import InvokeJsonResponse, LLMJsonClient -from wags_llm.prompts.base import BasePromptTemplate -from wags_llm.prompts.registry import PromptRegistry +from wags_llm.registry import Registry from wags_llm.services.structured_task import StructuredTaskRunner +from wags_llm.templates.prompt_template import PromptTemplate -class DummyPrompt(BasePromptTemplate): +class DummyPrompt(PromptTemplate): """Simple prompt for service tests.""" name = "test_task" @@ -81,12 +81,12 @@ class ResultModel(BaseModel): def test_run_success(): """Test that run method works correctly""" - registry = PromptRegistry() + registry = Registry() registry.register(DummyPrompt()) service = StructuredTaskRunner( client=DummyClient(), - prompt_registry=registry, + registry=registry, ) result = service.execute_prompt( @@ -101,14 +101,14 @@ def test_run_success(): def test_run_uses_cache(): """Test that run method works correctly with cache""" - registry = PromptRegistry() + registry = Registry() registry.register(DummyPrompt()) client = DummyClient() cache = InMemoryCache() service = StructuredTaskRunner( client=client, - prompt_registry=registry, + registry=registry, cache=cache, ) 
@@ -132,14 +132,14 @@ def test_run_uses_cache(): def test_run_cache_miss_for_different_payload(): """Test that run method works correctly with cache that uses different payload""" - registry = PromptRegistry() + registry = Registry() registry.register(DummyPrompt()) client = DummyClient() cache = InMemoryCache() service = StructuredTaskRunner( client=client, - prompt_registry=registry, + registry=registry, cache=cache, ) @@ -161,15 +161,15 @@ def test_run_cache_miss_for_different_payload(): def test_run_validation_error(): """Test that run raises error when response validation fails.""" - registry = PromptRegistry() + registry = Registry() registry.register(DummyPrompt()) service = StructuredTaskRunner( client=BadClient(), - prompt_registry=registry, + registry=registry, ) - with pytest.raises(RuntimeError, match="Task failed"): + with pytest.raises(RuntimeError, match="prompt execution failed"): service.execute_prompt( prompt_name="test_task", prompt_version="v1", diff --git a/tests/registry/test_registry.py b/tests/registry/test_registry.py new file mode 100644 index 0000000..4a60834 --- /dev/null +++ b/tests/registry/test_registry.py @@ -0,0 +1,128 @@ +"""Test that Registry works correctly""" + +import re +from collections.abc import Mapping +from pathlib import Path +from typing import Any + +import pytest + +from wags_llm.registry import Registry, build_empty_registry +from wags_llm.templates import TemplateType +from wags_llm.templates.prompt_template import PromptTemplate +from wags_llm.templates.skill_template import SkillTemplate, SkillTemplateError + + +class DummyPrompt(PromptTemplate): + """Simple prompt for registry tests.""" + + name = "test_example" + version = "0.1.0" + + def build_system_prompt(self) -> str: + """Build the system prompt. + + :return: System prompt string. + """ + return "Return valid JSON only." + + def build_user_prompt(self, payload: Mapping[str, Any]) -> str: + """Build the user prompt. 
+ + :param payload: JSON-serializable task data. + + Example: + payload = {"text": "hello"} + + :return: User prompt string. + """ + return f"Payload: {payload}" + + +class DummySkill(SkillTemplate): + """Simple skill for registry tests.""" + + skill_path = Path("tests/examples/test_example_0.1.0.md") + + def build_user_prompt(self, payload: Mapping[str, Any]) -> str: + """Build the user prompt. + + :param payload: JSON-serializable task data. + + Example: + payload = {"text": "hello"} + + :return: User prompt string. + """ + return f"Payload: {payload}" + + +def test_register_and_get_prompt(): + """Register and retrieve a prompt.""" + registry = Registry() + prompt = DummyPrompt() + + registry.register(prompt) + + assert registry.get("test_example", "0.1.0", TemplateType.PROMPT) is prompt + + +def test_register_and_get_skill(): + """Register and retrieve a skill.""" + registry = Registry() + skill = DummySkill() + + registry.register(skill) + + assert registry.get("test_example", "0.1.0", TemplateType.SKILL) is skill + + +def test_build_empty_registry(): + """Test that build_empty_registry works correctly and prompt registry raises KeyError when no prompts are registered""" + registry = build_empty_registry() + with pytest.raises( + KeyError, + match=re.escape("'Template not found: (test_example, 0.1.0, prompt)'"), + ): + assert registry.get("test_example", "0.1.0", TemplateType.PROMPT) + + +def test_invalid_skill_filename(): + """Test that an invalid skill filename raises SkillTemplateError.""" + + class InvalidSkill(SkillTemplate): + skill_path = Path("tests/examples/invalid.md") + + def build_user_prompt(self, payload) -> str: + return f"Payload: {payload}" + + with pytest.raises(SkillTemplateError): + _ = InvalidSkill().name + + +def test_prompt_and_skill_can_share_name_and_version(): + """Register prompt and skill with same name/version.""" + registry = Registry() + + prompt = DummyPrompt() + skill = DummySkill() + + registry.register(prompt) + 
registry.register(skill) + + assert registry.get("test_example", "0.1.0", TemplateType.PROMPT) is prompt + assert registry.get("test_example", "0.1.0", TemplateType.SKILL) is skill + + +def test_registering_duplicate_skill_raises_value_error(): + """Raise ValueError when registering duplicate skill name/version.""" + registry = Registry() + skill = DummySkill() + + registry.register(skill) + + with pytest.raises( + ValueError, + match=re.escape("Template already registered:(test_example, 0.1.0, skill)"), + ): + registry.register(DummySkill()) diff --git a/tests/unit/prompts/test_registry.py b/tests/unit/prompts/test_registry.py deleted file mode 100644 index f8f65cd..0000000 --- a/tests/unit/prompts/test_registry.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Test that PromptRegistry works correctly""" - -import re -from collections.abc import Mapping -from typing import Any - -import pytest - -from wags_llm.prompts.base import BasePromptTemplate -from wags_llm.prompts.registry import PromptRegistry, build_empty_registry - - -class DummyPrompt(BasePromptTemplate): - """Simple prompt for registry tests.""" - - name = "test_task" - version = "v1" - - def build_system_prompt(self) -> str: - """Build the system prompt. - - :return: System prompt string. - """ - return "Return valid JSON only." - - def build_user_prompt(self, payload: Mapping[str, Any]) -> str: - """Build the user prompt. - - :param payload: JSON-serializable task data. - - Example: - payload = {"text": "hello"} - - :return: User prompt string. 
- """ - return f"Payload: {payload}" - - -def test_register_and_get_prompt(): - """Register and retrieve a prompt.""" - registry = PromptRegistry() - prompt = DummyPrompt() - - registry.register(prompt) - - assert registry.get("test_task", "v1") is prompt - - -def test_build_empty_registry(): - """Test that build_empty_registry works correctly and prompt registry raises KeyError when no prompts are registered""" - registry = build_empty_registry() - with pytest.raises( - KeyError, match=re.escape("'Prompt not found: (test_task, v1)'") - ): - assert registry.get("test_task", "v1") diff --git a/tests/unit/skills/test_skill_json_task.py b/tests/unit/skills/test_skill_json_task.py index 68cceed..efdefcf 100644 --- a/tests/unit/skills/test_skill_json_task.py +++ b/tests/unit/skills/test_skill_json_task.py @@ -8,25 +8,25 @@ from wags_llm.cache.in_memory import InMemoryCache from wags_llm.client.base import InvokeJsonResponse, LLMJsonClient +from wags_llm.registry import Registry from wags_llm.services.structured_task import StructuredTaskRunner -from wags_llm.skills.base import BaseSkillTemplate, SkillTemplateError -from wags_llm.skills.registry import SkillRegistry +from wags_llm.templates.skill_template import SkillTemplate, SkillTemplateError -class DummySkill(BaseSkillTemplate): +class DummySkill(SkillTemplate): """Simple skill for service tests.""" - skill_path = Path("tests/unit/skills/test_skill_v1.md") + skill_path = Path("tests/examples/test_example_0.1.0.md") def build_user_prompt(self, payload) -> str: """Build the user prompt.""" return f"Payload: {payload}" -class MissingFileSkill(BaseSkillTemplate): +class MissingFileSkill(SkillTemplate): """Missing skill file for service tests.""" - skill_path = Path("tests/unit/skills/does_not_exist_v1.md") + skill_path = Path("tests/examples/does_not_exist_0.1.0.md") def build_user_prompt(self, payload) -> str: """Build the user prompt.""" @@ -87,17 +87,17 @@ class ResultModel(BaseModel): def 
test_execute_skill_success(): """Test that execute_skill works correctly.""" - registry = SkillRegistry() + registry = Registry() registry.register(DummySkill()) service = StructuredTaskRunner( client=DummyClient(), - skill_registry=registry, + registry=registry, ) result = service.execute_skill( - skill_name="test_skill", - skill_version="v1", + skill_name="test_example", + skill_version="0.1.0", payload={"text": "hello"}, response_model=ResultModel, ) @@ -108,18 +108,18 @@ def test_execute_skill_success(): def test_execute_skill_file_not_found(): """Test that execute_skill raises FileNotFoundError when skill file does not exist.""" - registry = SkillRegistry() + registry = Registry() registry.register(MissingFileSkill()) service = StructuredTaskRunner( client=DummyClient(), - skill_registry=registry, + registry=registry, ) with pytest.raises(SkillTemplateError): service.execute_skill( skill_name="does_not_exist", - skill_version="v1", + skill_version="0.1.0", payload={"text": "hello"}, response_model=ResultModel, ) @@ -127,26 +127,26 @@ def test_execute_skill_file_not_found(): def test_execute_skill_uses_cache(): """Test that execute_skill works correctly with cache.""" - registry = SkillRegistry() + registry = Registry() registry.register(DummySkill()) client = DummyClient() cache = InMemoryCache() service = StructuredTaskRunner( client=client, - skill_registry=registry, + registry=registry, cache=cache, ) result1 = service.execute_skill( - skill_name="test_skill", - skill_version="v1", + skill_name="test_example", + skill_version="0.1.0", payload={"x": 1}, response_model=ResultModel, ) result2 = service.execute_skill( - skill_name="test_skill", - skill_version="v1", + skill_name="test_example", + skill_version="0.1.0", payload={"x": 1}, response_model=ResultModel, ) @@ -158,26 +158,26 @@ def test_execute_skill_uses_cache(): def test_execute_skill_cache_miss_for_different_payload(): """Test that execute_skill cache misses on different payload.""" - registry = 
SkillRegistry() + registry = Registry() registry.register(DummySkill()) client = DummyClient() cache = InMemoryCache() service = StructuredTaskRunner( client=client, - skill_registry=registry, + registry=registry, cache=cache, ) service.execute_skill( - skill_name="test_skill", - skill_version="v1", + skill_name="test_example", + skill_version="0.1.0", payload={"x": 1}, response_model=ResultModel, ) service.execute_skill( - skill_name="test_skill", - skill_version="v1", + skill_name="test_example", + skill_version="0.1.0", payload={"x": 2}, response_model=ResultModel, ) @@ -187,18 +187,18 @@ def test_execute_skill_cache_miss_for_different_payload(): def test_execute_skill_validation_error(): """Test that execute_skill raises RuntimeError when response validation fails.""" - registry = SkillRegistry() + registry = Registry() registry.register(DummySkill()) service = StructuredTaskRunner( client=BadClient(), - skill_registry=registry, + registry=registry, ) - with pytest.raises(RuntimeError, match="Task failed"): + with pytest.raises(RuntimeError, match="skill execution failed"): service.execute_skill( - skill_name="test_skill", - skill_version="v1", + skill_name="test_example", + skill_version="0.1.0", payload={"text": "hello"}, response_model=ResultModel, ) diff --git a/tests/unit/skills/test_skill_registry.py b/tests/unit/skills/test_skill_registry.py deleted file mode 100644 index c58282d..0000000 --- a/tests/unit/skills/test_skill_registry.py +++ /dev/null @@ -1,55 +0,0 @@ -import re -from collections.abc import Mapping -from pathlib import Path -from typing import Any - -import pytest - -from wags_llm.skills.base import BaseSkillTemplate, SkillTemplateError -from wags_llm.skills.registry import SkillRegistry, build_empty_registry - - -class DummySkill(BaseSkillTemplate): - skill_path = Path("tests/unit/skills/test_skill_v1.md") - - def build_user_prompt(self, payload: Mapping[str, Any]) -> str: - """Build the user prompt. 
- - :param payload: JSON-serializable task data. - - Example: - payload = {"text": "hello"} - - :return: User prompt string. - """ - return f"Payload: {payload}" - - -def test_register_and_get_skill(): - registry = SkillRegistry() - skill = DummySkill() - - registry.register(skill) - - assert registry.get("test_skill", "v1") is skill - - -def test_build_empty_registry(): - registry = build_empty_registry() - with pytest.raises( - KeyError, match=re.escape("'Skill not found: (test_skill, v1)'") - ): - assert registry.get("test_skill", "v1") - - -def test_invalid_skill_filename(): - """Test that an invalid skill filename raises SkillTemplateError.""" - - class InvalidSkill(BaseSkillTemplate): - skill_path = Path("tests/unit/skills/invalid.md") - - def build_user_prompt(self, payload) -> str: - return f"Payload: {payload}" - - with pytest.raises(SkillTemplateError): - _ = InvalidSkill().name