diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..23cf457 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.mypy_cache +.venv +uv.lock diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..6324d40 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.14 diff --git a/docs/README-TEMPLATE.md b/docs/README-TEMPLATE.md new file mode 100644 index 0000000..e69de29 diff --git a/examples/database.py b/examples/database.py new file mode 100644 index 0000000..1354bb2 --- /dev/null +++ b/examples/database.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Self + +from pydantic import BaseModel + +from testing_utils import BaseUtils + + +class Transaction: + """ + TODO: + """ + + def __init__(self, utils: DatabaseUtils) -> None: + self._utils = utils + + async def commit(self) -> DatabaseUtils: + await self._utils.commit(self) + return self._utils + + +class DatabaseUtils(BaseUtils[Transaction, BaseModel]): + """ + TODO: + """ + + # TODO: add types + def __init__(self, db, **kwargs) -> None: + super().__init__(**kwargs) + self._db = db + + def fork(self, label: str = "") -> Self: + return self._fork(DatabaseUtils, label) + + def start(self) -> Transaction: + return Transaction(self) diff --git a/examples/endpoints.py b/examples/endpoints.py new file mode 100644 index 0000000..c01b6a5 --- /dev/null +++ b/examples/endpoints.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import Self + +from pydantic import BaseModel + +from testing_utils import BaseUtils, Model + + +class Transaction: + """ + TODO: + """ + + def __init__(self, utils: EndpointUtils) -> None: + self._utils = utils + + async def commit(self) -> EndpointUtils: + await self._utils.commit(self) + return self._utils + + +class EndpointUtils(BaseUtils[Transaction, BaseModel]): + """ + TODO: + """ + + # TODO: add types + def __init__(self, client, **kwargs) -> None: + super().__init__(**kwargs) + self._client = client + self._models = [ + Model(name="", requires=[]) + ] + + def fork(self, label: str = "") -> Self: + return self._fork(EndpointUtils, label) + + def start(self) -> Transaction: + return Transaction(self) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..204ce8c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "testing-utils" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.14" +dependencies = [] + +[dependency-groups] +dev = [ + "black>=25.11.0", + "isort>=7.0.0", + "mypy>=1.18.2", + "pytest>=9.0.1", +] diff --git a/testing_utils/__init__.py b/testing_utils/__init__.py new file mode 100644 index 0000000..42616f7 --- /dev/null +++ b/testing_utils/__init__.py @@ -0,0 +1,8 @@ +from .utils import BaseUtils +from .models import Model, FixtureSpec + +__all__ = [ + "BaseUtils", + "Model", + "FixtureSpec", +] diff --git a/testing_utils/models.py b/testing_utils/models.py new file mode 100644 index 0000000..93bbc31 --- /dev/null +++ b/testing_utils/models.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass +from typing import Any, Optional, TypeVar + + +@dataclass +class Model: + """ + TODO: + """ + name: str + requires: list[str] + plural: Optional[str] = None + + @property + def plural_name(self) -> str: + if self.plural is not None: + return self.plural + + return f"{self.name}s" + + +@dataclass +class FixtureSpec: + name: str + args: dict[str, Any] + + +@dataclass +class ModelWithFixture: + """ + TODO: + """ + model: Model + fixture: FixtureSpec + + +T = TypeVar("T") + + +def or_(*args: T | None) -> T: + for arg in args: + if arg is not None: + return arg + + assert False, "or_() was given no non-None values" diff --git a/testing_utils/py.typed b/testing_utils/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/testing_utils/sort.py b/testing_utils/sort.py new file mode 100644 index 0000000..a0a5155 --- /dev/null +++ b/testing_utils/sort.py @@ -0,0 +1,65 @@ +from .models import Model, FixtureSpec, ModelWithFixture + + +def topological_sort_and_fill( + models: list[Model], + fixtures: list[FixtureSpec], +) -> list[ModelWithFixture]: + models_with_fixtures: list[ModelWithFixture] = [] + fixture_model_names = {fixture.name for fixture in fixtures} + + for fixture in fixtures: + model = next((m for m in models if m.name == fixture.name), None) + + assert model is not None, f"Model {fixture.name} not found in models list" + + models_with_fixtures.append(ModelWithFixture(model=model, fixture=fixture)) + + visited_names = set[str]() + stack: list[ModelWithFixture] = [] + + def dfs(node: ModelWithFixture) -> None: + visited_names.add(node.model.name) + + for dependency in node.model.requires: + if ( + dependency not in fixture_model_names + and dependency not in visited_names + ): + # user didn't specify this model, so add in default + model_to_add = next( + (m for m in models if m.name == dependency), + None, + ) + + msg = f"Model {dependency} not found in models list" + assert model_to_add is not None, msg + + node_to_add = ModelWithFixture( + model=model_to_add, + fixture=FixtureSpec( + name=model_to_add.name, + args={}, + ), + ) + + dfs(node_to_add) + elif dependency not in visited_names: + # add parent dependency + user_specified_node_to_add: ModelWithFixture | None = next( + (m for m in models_with_fixtures if m.model.name == dependency), + None, + ) + + msg = f"Model {dependency} not found in models list" + assert user_specified_node_to_add is not None, msg + + dfs(user_specified_node_to_add) + + stack.append(node) + + for node in models_with_fixtures: + if node.model.name not in visited_names: + dfs(node) + + return stack diff --git a/testing_utils/utils.py b/testing_utils/utils.py new file mode 100644 index 0000000..ec58678 --- /dev/null +++ b/testing_utils/utils.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from logging import getLogger +from typing import Any, Optional, Self + +from .models import Model, FixtureSpec, or_ +from .sort import topological_sort_and_fill + +logger = getLogger("testing-utils") + + +class BaseUtils[TTransaction, TValue](ABC): + + def __init__( + self, + name: str = "root", + parent: Optional[Self] = None, + fixtures: Optional[list[FixtureSpec]] = None, + **kwargs: Any, + ) -> None: + self._data: dict[str, Any] = {} + self._name = name + self._parent = parent + # self._is_setup = False + # self._setup_complete = False + self._created_values: dict[str, TValue] = {} + self._children: list[BaseUtils] = [] + self._parent = parent + self._fixtures = or_(fixtures, []) + self._models: list[Model] = [] + self._kwargs = kwargs + + def _find_value(self, name: str) -> TValue | None: + # if not self._is_setup: + # msg = "TestingUtils is not set up. Call setup() before using it." + # raise RuntimeError(msg) + + if name in self._created_values: + return self._created_values[name] + + if self._parent is not None: + return self._parent._find_value(name) + + in_fixture = next( + (fixture for fixture in self._fixtures if fixture.name == name), + None, + ) + + # if self._setup_complete and in_fixture is None: + # raise Exception(f"accessing {name} not in fixture") + + return None + + def _add_fixture(self, fixture: FixtureSpec) -> None: + self._fixtures.append(fixture) + + def _get_value(self, name: str) -> TValue: + val = self._find_value(name) + + if val is not None: + return val + + msg = f"Value {name} not found in created values or parent." + raise RuntimeError(msg) + + @abstractmethod + def start(self) -> TTransaction: + """ + TODO: + """ + + async def _dispatch( + self, + tx: TTransaction, + model: Model, + data: dict[str, Any], + ) -> None: + repo = getattr(self, f"_get_{model.plural_name}_repo")() + create_defaults_func = getattr(tx, f"_create_{model.name}_defaults") + defaults = create_defaults_func(**data) + value = await getattr(repo, f"create_{model.name}")(**defaults) + self._created_values[model.name] = value + + def _dispatch_sync( + self, + tx: TTransaction, + model: Model, + data: dict[str, Any], + ) -> None: + repo = getattr(self, f"_get_{model.plural_name}_repo")() + create_defaults_func = getattr(tx, f"_create_{model.name}_defaults") + defaults = create_defaults_func(**data) + value = getattr(repo, f"create_{model.name}")(**defaults) + self._created_values[model.name] = value + + async def commit(self, tx: TTransaction) -> None: + """ + TODO: + """ + await self._commit_async(tx) + + children = self._children.copy() + + while len(children) > 0: + child = children.pop(0) + await child._commit_async(tx) + children.extend(child._children) + + def commit_sync(self, tx: TTransaction) -> None: + """ + TODO: + """ + self._commit_sync(tx) + + children = self._children.copy() + + while len(children) > 0: + child = children.pop(0) + child._commit_sync(tx) + children.extend(child._children) + + async def _commit_async(self, tx: TTransaction) -> None: + models_to_create = topological_sort_and_fill( + self._models, + self._fixtures, + ) + + for model_with_fixture in models_to_create: + if self._find_value(model_with_fixture.model.name) is not None: + continue + + logger.debug( + "%s creating %s with args: %s", + self._name, + model_with_fixture.model.name, + model_with_fixture.fixture.args, + ) + + await self._dispatch(tx, model_with_fixture.model, model_with_fixture.fixture.args) + + self._fixtures = [] + + def _commit_sync(self, tx: TTransaction) -> None: + models_to_create = topological_sort_and_fill( + self._models, + self._fixtures, + ) + + for model_with_fixture in models_to_create: + if self._find_value(model_with_fixture.model.name) is not None: + continue + + logger.debug( + "%s creating %s with args: %s", + self._name, + model_with_fixture.model.name, + model_with_fixture.fixture.args, + ) + + self._dispatch_sync(tx, model_with_fixture.model, model_with_fixture.fixture.args) + + self._fixtures = [] + + @abstractmethod + def fork(self, label: str = "") -> Self: + """ + TODO: + """ + + def _fork[T: BaseUtils](self, cls: type[T], label: str = "") -> T: + name = f"{self._name}.{len(self._children)}" + + if len(label) > 0: + name += f".{label}" + + child = cls( + name=name, + parent=self, + fixtures=self._fixtures.copy(), + **self._kwargs, + ) + self._children.append(child) + return child diff --git a/tests/test_sort.py b/tests/test_sort.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..e69de29