diff --git a/CHANGELOG.md b/CHANGELOG.md index 1abd8df..2b28679 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,12 @@ You can find our backwards-compatibility policy [here](https://github.com/hynek/ [#73](https://github.com/hynek/svcs/pull/73) +### Fixed + +- `Container.aget()` now also enters and exists synchronous context managers. + [#93](https://github.com/hynek/svcs/pull/93) + + ## [24.1.0](https://github.com/hynek/svcs/compare/23.21.0...24.1.0) - 2024-01-25 ### Fixed diff --git a/pyproject.toml b/pyproject.toml index 56167fb..e0c420a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,6 +182,10 @@ ignore_missing_imports = true module = "tests.*" ignore_errors = true +[[tool.mypy.overrides]] +module = "conftest" +ignore_errors = true + [[tool.mypy.overrides]] module = "tests.typing.*" ignore_errors = false diff --git a/src/svcs/_core.py b/src/svcs/_core.py index 42b44f4..ebef048 100644 --- a/src/svcs/_core.py +++ b/src/svcs/_core.py @@ -995,6 +995,9 @@ async def aget(self, *svc_types: type) -> object: Also works with synchronous services, so in an async application, just use this. + + .. versionchanged:: 24.2.0 + Synchronous context managers are now entered/exited, too. """ rv = [] for svc_type in svc_types: @@ -1008,6 +1011,9 @@ async def aget(self, *svc_types: type) -> object: svc = await svc.__aenter__() elif isawaitable(svc): svc = await svc + elif enter and isinstance(svc, AbstractContextManager): + self._on_close.append((name, svc)) + svc = svc.__enter__() self._instantiated[svc_type] = svc diff --git a/tests/test_container.py b/tests/test_container.py index bc38204..6e9f269 100644 --- a/tests/test_container.py +++ b/tests/test_container.py @@ -109,6 +109,26 @@ def scope(): "Container was garbage-collected with pending cleanups.", ) == recwarn.list[0].message.args + @pytest.mark.asyncio() + async def test_aget_enters_sync_contextmanagers(self, container): + """ + aget enters (and exits) synchronous context managers. + """ + is_closed = False + + def factory(): + yield 42 + + nonlocal is_closed + is_closed = True + + container.registry.register_factory(int, factory) + + async with container: + assert 42 == await container.aget(int) + + assert is_closed + class TestServicePing: def test_ping(self, registry, container, close_me):