diff --git a/aioinject/testing/testcontainer.py b/aioinject/testing/testcontainer.py index 5991cc6..25e5a42 100644 --- a/aioinject/testing/testcontainer.py +++ b/aioinject/testing/testcontainer.py @@ -1,6 +1,12 @@ from __future__ import annotations -from collections.abc import Iterator +from collections.abc import AsyncIterator, Iterator +from contextlib import ( + AsyncExitStack, + ExitStack, + asynccontextmanager, + contextmanager, +) from types import TracebackType from typing import TYPE_CHECKING, Any @@ -123,5 +129,22 @@ class TestContainer: def __init__(self, container: Container | SyncContainer) -> None: self.container = container - def override(self, provider: Provider[Any]) -> _Override: - return _Override(self.container, provider) + @asynccontextmanager + async def override(self, *providers: Provider[Any]) -> AsyncIterator[None]: + overrides = [ + _Override(self.container, provider) for provider in providers + ] + async with AsyncExitStack() as exitstack: + for override in overrides: + await exitstack.enter_async_context(override) + yield + + @contextmanager + def override_sync(self, *providers: Provider[Any]) -> Iterator[None]: + overrides = [ + _Override(self.container, provider) for provider in providers + ] + with ExitStack() as exitstack: + for override in overrides: + exitstack.enter_context(override) + yield diff --git a/tests/testing/test_testcontainer.py b/tests/testing/test_testcontainer.py index af3c4a7..26b769e 100644 --- a/tests/testing/test_testcontainer.py +++ b/tests/testing/test_testcontainer.py @@ -25,7 +25,7 @@ async def test_override() -> None: async with ( container.context() as ctx, ): - with testcontainer.override(Object(override)): + with testcontainer.override_sync(Object(override)): assert await ctx.resolve(int) is override async with container.context() as ctx: @@ -153,3 +153,30 @@ class C: c = await container.root.resolve(C) assert c is not c_mocked assert c.b.a is not a_mock + + +async def test_multiple_overrides() -> None: + class A: + pass + + @dataclasses.dataclass + class B: + a: A + + @dataclasses.dataclass + class C: + b: B + + container = Container() + container.register(Singleton(A), Singleton(B), Singleton(C)) + testcontainer = TestContainer(container) + + a_mock = A() + b_mock = B(a_mock) + async with testcontainer.override(Object(a_mock), Object(b_mock)): + c_mocked = await container.root.resolve(C) + + c = await container.root.resolve(C) + assert c is not c_mocked + assert c.b is not b_mock + assert c.b.a is not a_mock diff --git a/uv.lock b/uv.lock index 9c86b2c..e527a8b 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" [[package]] @@ -125,7 +125,7 @@ wheels = [ [[package]] name = "aioinject" -version = "1.10.1" +version = "1.10.2" source = { editable = "." } dependencies = [ { name = "typing-extensions" },