diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 144f63a35443ded2a61603999896d549e242088f..e59849eebb5685b172500a5cf21e9756ae0ad521 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -86,10 +86,10 @@ repos: rev: v1.15.0 hooks: - id: mypy - args: [kilter/service, tests] + args: [kilter/service, tests, --python-version=3.11] pass_filenames: false additional_dependencies: - - anyio ~=3.1 + - anyio ~=4.0 - kilter.protocol ~=0.6.0 - sphinx - trio-typing diff --git a/kilter/service/runner.py b/kilter/service/runner.py index 0a89fce645f9e2776c9c3c706b28796da32d649b..99334c723fe457d9c300c96d823bf93aac9b22ee 100644 --- a/kilter/service/runner.py +++ b/kilter/service/runner.py @@ -405,6 +405,6 @@ class _TaskRunner: def _make_message_channel() -> tuple[MessageChannel, MessageChannel]: - lsend, rrecv = anyio.create_memory_object_stream(1, Message) # type: ignore - rsend, lrecv = anyio.create_memory_object_stream(1, Message) # type: ignore + lsend, rrecv = anyio.create_memory_object_stream[Message](1) + rsend, lrecv = anyio.create_memory_object_stream[Message](1) return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv) diff --git a/pyproject.toml b/pyproject.toml index 9d0474501a027447d585c5819a3bfaeb4bbf6d15..89737ef2404aa5a2458ffe6e7640f67c5e4c4002 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,9 +11,9 @@ license = {file = "LICENCE.txt"} readme = "README.md" dynamic = ["version", "description"] -requires-python = "~=3.10" +requires-python = "~=3.11" dependencies = [ - "anyio ~=3.0", + "anyio ~=4.0", "async-generator ~=1.2", "kilter.protocol ~=0.6.0", "typing-extensions ~=4.0", @@ -27,7 +27,7 @@ classifiers = [ [project.optional-dependencies] tests = [ - "trio <0.22", # Until anyio supports BaseExceptionGroup + "trio", ] docs = [ "sphinx ~=5.0", diff --git a/tests/__init__.py b/tests/__init__.py index 231d607463e5956f4e0c4c81b13c596b8ee9c3e0..99857bed0ca782ab1652bcbeb238f7770faa2336 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,22 +2,39 @@ A package of tests for kilter.service modules """ +from __future__ import annotations + import functools import os from collections.abc import Callable from collections.abc import Coroutine +from collections.abc import Iterator +from contextlib import contextmanager from inspect import iscoroutinefunction +from typing import TYPE_CHECKING from typing import Any +from typing import Protocol +from typing import Self +from typing import TypeVar from unittest import TestCase import trio +E = TypeVar("E", bound=BaseException) + SyncTest = Callable[[TestCase], None] AsyncTest = Callable[[TestCase], Coroutine[Any, Any, None]] LIMIT_SCALE_FACTOR = float(os.environ.get("LIMIT_SCALE_FACTOR", 1)) +if TYPE_CHECKING: + class AssertRaisesContext(Protocol[E]): # noqa: D101 + exception: E + expected: type[BaseException] | tuple[type[BaseException], ...] + msg: str|None + + class AsyncTestCase(TestCase): """ A variation of `unittest.TestCase` with support for awaitable (async) test functions @@ -30,6 +47,28 @@ class AsyncTestCase(TestCase): if name.startswith("test_") and iscoroutinefunction(value): setattr(cls, name, _syncwrap(value, time_limit * LIMIT_SCALE_FACTOR)) + @contextmanager + def assertRaises( # type: ignore[override] + self, + expected_exception: type[E]|tuple[type[E], ...], + *, + msg: str|None = None, + ) -> Iterator[AssertRaisesContext[E]]: + """ + Return a context manager that asserts a given exception is raised with the context + + Extends the base assertRaises with support for ExceptionGroups. If at most one leaf + exception is raised in the group and it matches the expected type, it will be + treated as a successful failure. + """ + with super().assertRaises(expected_exception, msg=msg) as context: + try: + yield context + except* expected_exception as grp: + exc = [*_leaf_exc(grp)] + assert len(exc) == 1 + raise exc[0] from grp + def _syncwrap(test: AsyncTest, time_limit: float) -> SyncTest: @functools.wraps(test) @@ -41,3 +80,11 @@ def _syncwrap(test: AsyncTest, time_limit: float) -> SyncTest: raise TimeoutError trio.run(limiter) return wrap + + +def _leaf_exc(group: BaseExceptionGroup) -> Iterator[BaseException]: + for exc in group.exceptions: + if isinstance(exc, BaseExceptionGroup): + yield from _leaf_exc(exc) + else: + yield exc diff --git a/tests/mock_stream.py b/tests/mock_stream.py index c2ccca59452fa765aa6a540a4549b9782ae13d25..05458997b7932178bdd185b2eecb5edf45a6b8ec 100644 --- a/tests/mock_stream.py +++ b/tests/mock_stream.py @@ -50,8 +50,8 @@ class MockMessageStream: self.closed = False async def __aenter__(self) -> Self: - send_obj, recv_bytes = anyio.create_memory_object_stream(5, bytes) - send_bytes, recv_obj = anyio.create_memory_object_stream(5, bytes) + send_obj, recv_bytes = anyio.create_memory_object_stream[bytes](5) + send_bytes, recv_obj = anyio.create_memory_object_stream[bytes](5) self._stream = StapledObjectStream(send_obj, recv_obj) self.peer_stream = StapledByteStream(