Skip to content
Commits on Source (21)
......@@ -12,7 +12,5 @@ build/
dist/
# unit-testing, coverage, etc.
/*.xml
/*.json
.coverage*
.noseids
/*_cache/
/results/
stages:
- test
- acceptance
- build
- publish
include:
- project: dom/project-templates
file: /pipeline-templates/pre-commit.yml
- project: dom/project-templates
ref: tmp
file: /pipeline-templates/python-package.yml
image: python:3.8
variables:
PIP_CACHE_DIR: $CI_PROJECT_DIR/.cache/pip
cache:
key: all
paths:
- .cache/
- .eggs/
.python:
image: python:3.12
variables:
PIP_CACHE_DIR: $CI_PROJECT_DIR/cache/pkg
PIP_NO_COMPILE: "true"
PIP_NO_CLEAN: "true"
cache:
key: $CI_JOB_IMAGE
paths: [cache]
# Unit Tests
.unittest:
stage: test
image: python:$PY_VERSION
before_script:
- pip install -e .[test] nose2
- pip install -e .[test] coverage pytest
- mkdir results
script:
- nose2 -v -c setup.cfg
- coverage run -m pytest -v --junit-xml=results/junit.xml tests/unit
after_script:
- mv .coverage .coverage.$PY_VERSION
- mv .unittest.xml .unittest.$PY_VERSION.xml
coverage: '/^TOTAL.* ([0-9.]+\%)$/'
- mv results results.$PY_VERSION
artifacts:
when: always
paths:
- .noseids
- .coverage.*
- .unittest.$PY_VERSION.xml
- results.$PY_VERSION
reports:
junit: .unittest.$PY_VERSION.xml
junit: results.$PY_VERSION/junit.xml
unittest:3.8:
extends: .unittest
Unit Tests (Py 3.10):
extends: [.python, .unittest]
variables:
PY_VERSION: '3.8'
PY_VERSION: '3.10'
unittest:3.9:
extends: .unittest
Unit Tests (Py 3.11):
extends: [.python, .unittest]
variables:
PY_VERSION: '3.9'
PY_VERSION: '3.11'
unittest:3.10:
extends: .unittest
Unit Tests (Py 3.12):
extends: [.python, .unittest]
variables:
PY_VERSION: '3.10-rc'
PY_VERSION: '3.12'
publish:unittests:
stage: publish
Publish Unit Tests:
stage: deploy
extends: [.python]
when: always
dependencies: &unittests
- unittest:3.8
- unittest:3.9
- unittest:3.10
needs: *unittests
needs: &unittests
- Unit Tests (Py 3.10)
- Unit Tests (Py 3.11)
- Unit Tests (Py 3.12)
script:
- pip install --upgrade junit2html
- mkdir -p unittest
- python util/junit_merge.py .unittest.*.xml > .unittest.xml
- junit2html .unittest.xml unittest/index.html
- pip install --upgrade junit2html
- mkdir -p unittest
- python util/junit_merge.py results.*/junit.xml > junit.xml
- junit2html junit.xml unittest/index.html
artifacts:
when: always
paths:
- .unittest.xml
- unittest
- junit.xml
- unittest
reports:
junit: junit.xml
# Aggregate Coverage
coverage:
stage: acceptance
Aggregate Coverage:
stage: test
extends: [.python]
when: always
dependencies: *unittests
needs: *unittests
script:
- pip install --upgrade coverage
- coverage combine
- coverage report
- pip install --upgrade coverage
- coverage combine results.*/coverage.db
- coverage report
coverage: '/^TOTAL.* ([0-9.]+\%)$/'
artifacts:
when: always
paths:
- .coverage
- results
publish:coverage:
stage: publish
Publish Coverage:
stage: deploy
extends: [.python]
when: always
dependencies: [coverage]
needs: [coverage]
needs: [Aggregate Coverage]
script:
- pip install --upgrade coverage
- coverage html --fail-under=0 -d coverage
- coverage xml --fail-under=0 -o coverage/coverage.xml
- coverage html --fail-under=0 -d results/coverage.html
- coverage xml --fail-under=0 -o results/coverage.xml
artifacts:
when: always
paths:
- coverage
# Quality Assurance
Code Analysis:
stage: test
variables:
FROM_REF: $CI_DEFAULT_BRANCH
rules:
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
variables:
FROM_REF: $CI_COMMIT_BEFORE_SHA
- when: always
before_script:
- pip install pre-commit
script:
- git fetch $CI_REPOSITORY_URL $FROM_REF:FROM_REF -f
- pre-commit run
--hook-stage=commit
--from-ref=FROM_REF
--to-ref=${CI_COMMIT_SHA}
Commit Graph Analysis:
stage: test
variables:
FROM_REF: $CI_DEFAULT_BRANCH
rules:
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
variables:
FROM_REF: $CI_COMMIT_BEFORE_SHA
- if: $CI_PIPELINE_TRIGGERED == "merge_request_event"
before_script:
- pip install pre-commit
script:
- pre-commit run
--hook-stage=push
--from-ref=FROM_REF
--to-ref=${CI_COMMIT_SHA}
# Package publishing
Check Tag:
stage: test
rules:
- if: $CI_COMMIT_TAG =~ /^v[0-9]/
script:
- test `./setup.py --version` == ${CI_COMMIT_TAG#v}
Build Packages:
stage: build
script:
- ./setup.py bdist_wheel sdist
artifacts:
paths:
- dist
Upload Packages:
stage: publish
rules:
- if: $CI_COMMIT_TAG =~ /^v[0-9]/
dependencies:
- Build Packages
script:
- pip install twine
- TWINE_PASSWORD=$CI_JOB_TOKEN
TWINE_USERNAME=gitlab-ci-token
twine upload
--verbose
--non-interactive
--repository-url $CI_API_V4_URL/projects/$CI_PROJECT_ID/packages/pypi
dist/*
- results
......@@ -4,9 +4,10 @@ repos:
- repo: meta
hooks:
- id: check-hooks-apply
- id: check-useless-excludes
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.4.0
rev: v4.6.0
hooks:
- id: check-added-large-files
- id: check-case-conflict
......@@ -17,77 +18,84 @@ repos:
- id: debug-statements
- id: destroyed-symlinks
- id: end-of-file-fixer
stages: [commit]
stages: [commit, manual]
- id: fix-byte-order-marker
- id: fix-encoding-pragma
args: [--remove]
- id: mixed-line-ending
args: [--fix=lf]
stages: [commit, manual]
- id: trailing-whitespace
exclude_types: [markdown, plain-text]
stages: [commit]
stages: [commit, manual]
- repo: https://github.com/jorisroovers/gitlint
rev: v0.15.0
rev: v0.19.1
hooks:
- id: gitlint
- repo: https://code.kodo.org.uk/dom/pre-commit-hooks
rev: v0.6
rev: v0.6.1
hooks:
- id: check-executable-modes
- id: check-for-squash
- id: copyright-notice
args: [--min-size=1]
exclude: setup\.py
stages: [commit, manual]
- id: protect-first-parent
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.8.0
rev: v1.10.0
hooks:
- id: python-no-eval
- id: python-no-log-warn
- id: python-use-type-annotations
- repo: https://github.com/hakancelik96/unimport
rev: 0.8.4
- repo: https://github.com/hakancelikdev/unimport
rev: 1.2.1
hooks:
- id: unimport
args: [--remove, --exclude=types.py|__init__.py]
stages: [commit, manual]
- repo: https://github.com/timothycrosley/isort
rev: 5.7.0
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
types: [python]
stages: [commit, manual]
- repo: https://github.com/asottile/add-trailing-comma
rev: v2.1.0
rev: v3.1.0
hooks:
- id: add-trailing-comma
args: [--py36-plus]
stages: [commit, manual]
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.3
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.5
hooks:
- id: flake8
args: ["--config=setup.cfg"]
additional_dependencies:
- flake8-bugbear
- flake8-docstrings
- flake8-print
- flake8-requirements
- flake8-return
- flake8-sfs
- flake8-tabs
- id: ruff
exclude: "^util/"
args: [--fix, --unsafe-fixes]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.910
rev: v1.11.1
hooks:
- id: mypy
args: [--config-file=setup.cfg]
additional_dependencies: [anyio, trio-typing]
exclude: setup\.py|test_.*
args:
- --python-version=3.10
- --follow-imports=silent
- wpa_supplicant
pass_filenames: false
additional_dependencies: &type-deps
- anyio ~=4.0
- trio-typing
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.373
hooks:
- id: pyright
args: ["--pythonversion=3.11"]
pass_filenames: false
additional_dependencies: *type-deps
line-length = 92
indent-width = 1 # Used for line length violations
[lint]
select = [
# pyflakes
# --------
# ENABLE "Undefined name %s in __all__"
"F822",
# ENABLE "Local variable %s ... referenced before assignment"
"F823",
# ENABLE "Local variable %s is assigned to but never used"
"F841",
# ENABLE "raise NotImplemented should be raise NotImplementedError"
# mypy has particular trouble with this one:
# https://github.com/python/mypy/issues/5710
"F901",
# pycodestyle
# -----------
# Warnings not considered, many are not relevant to Python ~=3.9 and will
# cause syntax errors anyway, others concern whitespace which is fixed by
# a pre-commit hook.
"E",
# mccabe
# ------
"C90",
# pydocstyle
# ----------
# Missing docstrings
"D1",
# Whitespace Issues
"D2",
# ENABLE "Use “””triple double quotes”””"
"D300",
# First line should be descriptive, imperative and capitalised
"D401", "D402", "D403", "D404",
# ENABLE "Function/Method decorated with @overload shouldn’t contain a docstring"
"D418",
# flake8-bugbear
# --------------
# The bulk of bugbear's checks are useful
"B0",
# Various others
# --------------
"UP", "BLE", "FBT", "A", "COM", "C4", "DTZ", "ISC", "LOG", "G", "PIE", "T",
"Q", "RSE", "RET", "SLF", "SLOT", "SIM", "TD", "ANN", #"FA",
# Nice to have, needs fixing in several places though...
# "EM", "TCH", "PTH", "PGH",
]
ignore = [
# pycodestyle
# -----------
# DISABLE "Indentation contains mixed spaces and tabs"
# Will cause a syntax error if critical, otherwise in docstrings it is
# sometimes nice to use different indentation for "outer" (code) indentation
# and "inner" (documentation) indentation.
"E101",
# DISABLE "Continuation line missing indentation or outdented"
# "E122",
# DISABLE "Missing whitespace around bitwise or shift operator"
"E227",
# DISABLE "missing whitespace around arithmetic operator"
"E226",
# DISABLE "Line too long"
# Prefer B950 implementation
"E501",
# DISABLE "Multiple statements on one line (colon)"
"E701",
# DISABLE "Multiple statements on one line (def)"
# Doesn't work well with @overload definitions
# "E704",
# pydocstyle
# ----------
# DISABLE "Missing docstring in magic method"
# Magic/dunder methods are well-known
"D105",
# DISABLE "Missing docstring in __init__"
# Document basic construction in the class docstring
"D107",
# DISABLE "One-line docstring should fit on one line with quotes"
# Prefer top-and-bottom style always
"D200",
# DISABLE "1 blank line required before class docstring"
"D203",
# DISABLE "Docstring should be indented with spaces, not tabs"
# Tabs, absolutely always
"D206",
# DISABLE "Multi-line docstring summary should start at the first line"
"D212",
# flake8-bugbear
# --------------
# DISABLE "Do not use mutable data structures for argument defaults [...]"
# Would be nice if could take into account use as a non-mutable type
"B006",
# DISABLE "release is an empty method in an abstract base class, [...]"
# Until abstract methods are optional, empty optional "abstract" methods
# stay
"B027",
# Use named-tuples (preferably class based) for data-only classes
# "B903",
# Replacement for E501
# "B950",
# flake8-return
# -------------
# DISABLE "missing explicit return at the end of function able to return
# non-None value"
# Mypy will report this, plugin also cannot do exhaustiveness check of match
# block, leading to false-positives.
"RET503",
# DISABLE "Missing type annotation for `%` in {method|classmethod}"
# Don't type 'self' or 'cls'
"ANN101", "ANN102",
# DISABLE "Boolean positional value in function call"
# Too many stdlib functions take a single positional-only boolean. ruff
# can't interpret function signatures to ignore these and doesn't understand
# types to allow-list methods.
"FBT003",
# DISABLE "Implicitly concatenated string literals over multiple lines"
# It sometimes looks better to do this than introduce unecessary
# parentheses.
"ISC002",
# Unfortunately a lot of single quotes strings used in this project already
"Q000",
]
[lint.per-file-ignores]
"**/__init__.py" = ["D104"]
"**/__main__.py" = ["D100", "E702"]
"**/_*.py" = ["D1"]
"examples/**.py" = ["T"]
"tests/*" = ["D1"]
"doc/*" = ["D"]
"README.md" = ["D"]
[build-system]
requires = ["setuptools>=40.8.0", "wheel"]
build-backend = "setuptools.build_meta:__legacy__"
requires = ["flit_core ~=3.8"]
build-backend = "flit_core.buildapi"
[project]
name = "wpa-supplicant-client"
version = "0.3.0"
description = "A client package for connecting to, configuring and controlling wpa_supplicant daemons"
license = {file = "LICENCE.txt"}
authors = [
{name = "Dom Sekotill", email = "dom.sekotill@kodo.org.uk"},
]
classifiers = [
"Intended Audience :: Developers",
"Operating System :: POSIX",
]
requires-python = "~=3.10"
dependencies = [
"anyio ~=4.1",
]
[project.optional-dependencies]
test = [
"trio",
]
[project.urls]
Repository = "https://code.kodo.org.uk/dom/wpa-supplicant-client"
Issues = "https://code.kodo.org.uk/dom/wpa-supplicant-client/-/issues"
[tool.flit.module]
name = "wpa_supplicant"
[tool.isort]
force_single_line = true
line_length = 92
[tool.unimport]
ignore-init = true
[tool.mypy]
allow_redefinition = true
explicit_package_bases = true
implicit_reexport = true
strict = true
warn_unreachable = true
warn_unused_configs = true
[tool.pyright]
include = ["wpa_supplicant"]
typeCheckingMode = "strict"
reportMissingModuleSource = "none"
reportRedeclaration = "none"
reportUnknownMemberType = "warning"
reportUnusedImport = "warning"
[tool.coverage.run]
data_file = "results/coverage.db"
branch = true
source = ["wpa_supplicant"]
[tool.coverage.report]
precision = 2
skip_empty = true
exclude_lines = [
"pragma: no-cover",
"if .*\\b__name__\\b",
"if .*\\bTYPE_CHECKING\\b",
"class .*(.*\\bProtocol\\b.*):",
"def __repr__",
"@overload",
"@(abc\\.)abstractmethod",
]
partial_branches = [
"pragma: no-branch",
"if .*\\b__debug__\\b",
]
[tool.coverage.json]
output = "results/coverage.json"
show_contexts = true
[tool.coverage.xml]
output = "results/coverage.xml"
[tool.coverage.html]
directory = "results/coverage"
show_contexts = true
[metadata]
name = wpa-supplicant-client
version = attr: wpa_supplicant.__version__
author = Dom Sekotill
author_email = dom.sekotill@kodo.org.uk
description = A client package for connecting to, configuring and controlling wpa_supplicant daemons
long_description = file: README.md
long_description_content_type = text/markdown
url = 'https://code.kodo.org.uk/dom/wpa-supplicant-client.git'
license = Apache-2.0
license_files =
LICENCE.txt
classifiers =
Development Status :: 2 - Pre-Alpha
Intended Audience :: Developers
License :: OSI Approved
License :: OSI Approved :: Apache Software License
Natural Language :: English
Operating System :: POSIX
Programming Language :: Python
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Typing::Typed
[options]
python_requires = >= 3.8
packages = find:
setup_requires =
setuptools >= 40.6
install_requires =
anyio ~=3.0
[options.packages.find]
include =
wpa_supplicant
wpa_supplicant.*
[options.package_data]
wpa_supplicant = py.typed
[options.extras_require]
test =
nose2[coverage_plugin]
trio
[unittest]
start-dir = tests/unit
verbose = True
plugins =
nose2.plugins.junitxml
nose2.plugins.testid
[coverage]
always-on = True
coverage = wpa_supplicant
coverage-report = term
[coverage:run]
branch = True
[coverage:report]
fail_under = 80
precision = 2
show_missing = True
omit = **/__main__.py
exclude_lines =
pragma: no cover
if __name__ == .__main__.:
def __repr__
__version__ =
@(.*\.)?abstract((static|class)?method|property)
except ImportError:\s*(...|pass)
[coverage:xml]
output = .coverage.xml
[coverage:html]
directory = .coverage.html.d
[junit-xml]
always-on = True
path = .unittest.xml
[testid]
always-on = True
[log-capture]
always-on = True
[isort]
force_single_line = true
[mypy]
strict = true
warn_unused_configs = True
warn_unreachable = true
implicit_reexport = true
[flake8]
max-line-length = 92
max-doc-length = 92
use-flake8-tabs = true
blank-lines-indent = never
indent-tabs-def = 1
format = pylint
select = C,D,E,ET,F,SFS,T,W,WT
extend-exclude =
setup.py
per-file-ignores =
setup.py: D100, E702
tests/*.py: D100, C801
**/__init__.py: D104, F401, F403
**/__main__.py: D100, E702
**/_*.py: D
ignore =
;[ Missing docstring in public method ]
; Handled by pylint, which does it better
D102
;[ Missing docstring in magic method ]
; Magic/dunder methods are well-known
D105
;[ Misisng docstring in __init__ ]
; Document basic construction in the class docstring
D107
;[ One-line docstring should fit on one line with quotes ]
; Prefer top-and-bottom style always
D200
;[ Docstring should be indented with spaces, not tabs ]
; Tabs, absolutely always
D206
;[ Use u""" for Unicode docstrings ]
; This must be for Python 2?
D302
;[ First line should end with a period ]
; First line should *NEVER* end with a period
D400
;[ First line should be in the imperative mood ]
; I like this for functions and methods, not for properties. This stands until
; pydocstyle splits a new code for properties or flake8 adds some way of
; filtering codes with line regexes like golangci-lint.
D401
;[ Line too long ]
; Prefer B950 implementation
E501
;[ multiple statements on one line (%s) ]
E701 E704
;[ unexpected number of spaces at start of statement line ]
;[ unexpected number of tabs and spaces at start of statement line ]
; Don't want spaces...
ET122 ET128
;[ Line break before binary operator ]
; Not considered current
W503
;[ Format-method string formatting ]
; Allow this style
SFS201
;[ f-string string formatting ]
; Allow this style
SFS301
include =
;[ First word of the docstring should not be This ]
D404
; flake8-bugbear plugin
; B950 is a replacement for E501
B0 B903 B950
; vim: sw=2 sts=2 expandtab
#!/usr/bin/env python3
"""Setuptools entrypoint"""
from setuptools import setup
setup()
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -16,74 +16,27 @@
Anyio helpers for unit tests
"""
import sys
from functools import wraps
from typing import Any
from typing import Callable
from typing import Coroutine
from typing import Literal
from typing import Tuple
from typing import Union
from unittest import TestCase
from typing import Awaitable
from unittest import mock
from warnings import warn
import anyio
try:
import trio as _ # noqa
USE_TRIO = True
except ImportError:
USE_TRIO = False
Backend = Union[Literal['asyncio'], Literal['trio']]
def _delay_side_effect(delay: float) -> Awaitable[None]:
async def coro(*a: object, **k: object) -> None:
await anyio.sleep(delay)
return coro
py_version = sys.version_info[:2]
AsyncTestFunc = Callable[..., Coroutine[Any, Any, None]]
TestFunc = Callable[..., None]
def patch_connect(delay: float = 0.0) -> mock._patch:
return mock.patch(
"wpa_supplicant.client.base.connect_unix_datagram",
side_effect=_delay_side_effect(delay),
)
def with_anyio(*backends: Backend, timeout: int = 10) -> Callable[[AsyncTestFunc], TestFunc]:
"""
Create a wrapping decorator to run asynchronous test functions
"""
if not backends:
backends = ('asyncio',)
def decorator(testfunc: AsyncTestFunc) -> TestFunc:
async def test_async_wrapper(tc: TestCase, args: Tuple[mock.Mock]) -> None:
with anyio.fail_after(timeout):
await testfunc(tc, *args)
@wraps(testfunc)
def test_wrapper(tc: TestCase, *args: mock.Mock) -> None:
for backend in backends:
if backend == 'trio' and not USE_TRIO:
warn(
f"not running {testfunc.__name__} with trio; package is missing",
)
continue
with tc.subTest(f"backend: {backend}"):
anyio.run(test_async_wrapper, tc, args, backend=backend)
return test_wrapper
return decorator
class AsyncMock(mock.Mock):
"""
A Mock class that acts as a coroutine when called
"""
def __init__(self, *args: Any, delay: float = 0.0, **kwargs: Any):
mock._safe_super(AsyncMock, self).__init__(*args, **kwargs) # type: ignore
self.delay = delay
async def __call__(_mock_self, *args: Any, **kwargs: Any) -> Any:
_mock_self._mock_check_sig(*args, **kwargs)
if py_version >= (3, 8):
_mock_self._increment_mock_call(*args, **kwargs)
await anyio.sleep(_mock_self.delay)
return _mock_self._mock_call(*args, **kwargs)
def patch_send(delay: float = 0.0) -> mock._patch:
return mock.patch(
"wpa_supplicant.client.base.BaseClient.send_command",
side_effect=_delay_side_effect(delay),
)
# Copyright 2021 Dom Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,22 +18,20 @@ Test connecting and communicating with a server
import os
import sys
from unittest import TestCase
import unittest
from tests._anyio import with_anyio
from tests.integration.util import start_server
from wpa_supplicant.client import GlobalClient
class Tests(TestCase):
class Tests(unittest.IsolatedAsyncioTestCase):
"""
Tests against live wpa_suppplicant servers
The 'wpa_supplicant' executable is required in a PATH directory for these tests to work.
"""
@with_anyio('asyncio', 'trio')
async def test_connect(self):
async def test_connect(self) -> None:
"""
Test connecting to the global wpa_supplicant control socket
"""
......@@ -42,8 +40,7 @@ class Tests(TestCase):
ifaces = await client.list_interfaces()
assert len(ifaces) == 0
@with_anyio('asyncio', 'trio')
async def test_new_interface(self):
async def test_new_interface(self) -> None:
"""
Test adding a wireless interface and scanning for stations
......
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -17,74 +17,62 @@ Test cases for wpa_supplicant.client.base.BaseClient
"""
import unittest
from unittest import mock
from unittest.mock import AsyncMock
import anyio
from tests import _anyio as anyio_mock
from tests._anyio import patch_connect
from tests._anyio import patch_send
from wpa_supplicant import errors
from wpa_supplicant.client import base
@mock.patch(
"wpa_supplicant.client.base.connect_unix_datagram",
new_callable=anyio_mock.AsyncMock,
)
@mock.patch(
"wpa_supplicant.client.base.BaseClient.send_command",
new_callable=anyio_mock.AsyncMock,
)
class ConnectTests(unittest.TestCase):
class ConnectTests(unittest.IsolatedAsyncioTestCase):
"""
Tests for the connect() method
"""
@anyio_mock.with_anyio()
async def test_connect(self, _, connect_mock):
async def test_connect(self) -> None:
"""
Check connect() calls socket.connect()
"""
async with base.BaseClient() as client:
await client.connect("foo")
with patch_connect() as connect_mock, patch_send():
async with base.BaseClient() as client:
await client.connect("foo")
connect_mock.assert_called_once_with("foo")
connect_mock.assert_awaited_once_with("foo")
@anyio_mock.with_anyio()
async def test_connect_timeout_1(self, _, connect_mock):
async def test_connect_timeout_1(self) -> None:
"""
Check a socket.connect() delay causes TimeoutError to be raised
"""
connect_mock.delay = 2
async with base.BaseClient() as client:
with self.assertRaises(TimeoutError):
await client.connect("foo")
with patch_connect(2.0), patch_send():
async with base.BaseClient() as client:
with self.assertRaises(TimeoutError):
await client.connect("foo")
@anyio_mock.with_anyio()
async def test_connect_timeout_2(self, send_mock, _):
async def test_connect_timeout_2(self) -> None:
"""
Check a send/recv delay causes a TimeoutError to be raised
"""
send_mock.delay = 2
async with base.BaseClient() as client:
with self.assertRaises(TimeoutError):
await client.connect("foo")
with patch_connect(), patch_send(2.0):
async with base.BaseClient() as client:
with self.assertRaises(TimeoutError):
await client.connect("foo")
class SendMessageTests(unittest.TestCase):
class SendMessageTests(unittest.IsolatedAsyncioTestCase):
"""
Tests for the send_command() method
"""
def setUp(self):
def setUp(self) -> None:
self.client = client = base.BaseClient()
client.sock = anyio_mock.AsyncMock(spec=anyio.abc.SocketStream)
client.sock = AsyncMock(spec=anyio.abc.SocketStream)
client.sock.send.return_value = None
assert isinstance(client.sock, anyio.abc.SocketStream)
@anyio_mock.with_anyio()
async def test_simple(self):
async def test_simple(self) -> None:
"""
Check that a response is processed after a command
"""
......@@ -92,8 +80,7 @@ class SendMessageTests(unittest.TestCase):
client.sock.receive.return_value = b"OK"
assert await client.send_command("SOME_COMMAND") is None
@anyio_mock.with_anyio()
async def test_simple_expect(self):
async def test_simple_expect(self) -> None:
"""
Check that an alternate expected response is processed
"""
......@@ -101,8 +88,7 @@ class SendMessageTests(unittest.TestCase):
client.sock.receive.return_value = b"PONG"
assert await client.send_command("PING", expect="PONG") is None
@anyio_mock.with_anyio()
async def test_simple_no_expect(self):
async def test_simple_no_expect(self) -> None:
"""
Check that an unexpected response raises an UnexpectedResponseError
"""
......@@ -113,8 +99,7 @@ class SendMessageTests(unittest.TestCase):
with self.assertRaises(errors.UnexpectedResponseError):
await client.send_command("PING", expect="PONG")
@anyio_mock.with_anyio()
async def test_simple_convert(self):
async def test_simple_convert(self) -> None:
"""
Check that a response is passed through a converter if given
"""
......@@ -127,8 +112,7 @@ class SendMessageTests(unittest.TestCase):
["FOO", "BAR", "BAZ"],
)
@anyio_mock.with_anyio()
async def test_simple_convert_over_expect(self):
async def test_simple_convert_over_expect(self) -> None:
"""
Check that 'convert' overrides 'expect'
"""
......@@ -141,8 +125,7 @@ class SendMessageTests(unittest.TestCase):
["FOO", "BAR", "BAZ"],
)
@anyio_mock.with_anyio()
async def test_simple_fail(self):
async def test_simple_fail(self) -> None:
"""
Check that a response of 'FAIL' causes CommandFailed to be raised
"""
......@@ -151,8 +134,7 @@ class SendMessageTests(unittest.TestCase):
with self.assertRaises(errors.CommandFailed):
await client.send_command("SOME_COMMAND")
@anyio_mock.with_anyio()
async def test_simple_bad_command(self):
async def test_simple_bad_command(self) -> None:
"""
Check that a response of 'UNKNOWN COMMAND' causes ValueError to be raised
"""
......@@ -161,8 +143,7 @@ class SendMessageTests(unittest.TestCase):
with self.assertRaises(ValueError):
await client.send_command("SOME_COMMAND")
@anyio_mock.with_anyio()
async def test_interleaved(self):
async def test_interleaved(self) -> None:
"""
Check that messages are processed alongside replies
"""
......@@ -175,8 +156,7 @@ class SendMessageTests(unittest.TestCase):
]
assert await client.send_command("SOME_COMMAND") is None
@anyio_mock.with_anyio()
async def test_unexpected(self):
async def test_unexpected(self) -> None:
"""
Check that unexpected replies are logged cleanly
"""
......@@ -190,8 +170,7 @@ class SendMessageTests(unittest.TestCase):
]
assert await client.event("CTRL-EVENT-EXAMPLE")
@anyio_mock.with_anyio()
async def test_unconnected(self):
async def test_unconnected(self) -> None:
"""
Check that calling send_command() on an unconnected client raises RuntimeError
"""
......@@ -200,8 +179,7 @@ class SendMessageTests(unittest.TestCase):
with self.assertRaises(RuntimeError):
await client.send_command("SOME_COMMAND")
@anyio_mock.with_anyio()
async def test_multi_task(self):
async def test_multi_task(self) -> None:
"""
Check that calling send_command() from multiple tasks works as expected
"""
......@@ -213,7 +191,7 @@ class SendMessageTests(unittest.TestCase):
(0.0, b"OK"), # Response to DETACH
])
async def recv():
async def recv() -> bytes:
delay, data = next(recv_responses)
await anyio.sleep(delay)
return data
......@@ -222,7 +200,7 @@ class SendMessageTests(unittest.TestCase):
client.sock.receive.side_effect = recv
@task_group.start_soon
async def wait_for_event():
async def wait_for_event() -> None:
self.assertTupleEqual(
await client.event("CTRL-FOO"),
(base.EventPriority.INFO, "CTRL-FOO", None),
......@@ -235,8 +213,7 @@ class SendMessageTests(unittest.TestCase):
# At this point the response to SOME_COMMAND1 is still delayed
await client.send_command("SOME_COMMAND2", expect="REPLY2")
@anyio_mock.with_anyio()
async def test_multi_task_decode_error(self):
async def test_multi_task_decode_error(self) -> None:
"""
Check that decode errors closes the socket and causes all tasks to raise EOFError
"""
......@@ -251,27 +228,26 @@ class SendMessageTests(unittest.TestCase):
client.sock.receive.side_effect = recv_responses
@task_group.start_soon
async def wait_for_event():
async def wait_for_event() -> None:
with self.assertRaises(anyio.ClosedResourceError):
await client.event("CTRL-FOO"),
await client.event("CTRL-FOO")
await anyio.sleep(0.1) # Ensure send_command("ATTACH") has been sent
with self.assertRaises(anyio.ClosedResourceError):
await client.send_command("SOME_COMMAND", expect="REPLY")
class EventTests(unittest.TestCase):
class EventTests(unittest.IsolatedAsyncioTestCase):
"""
Tests for the event() method
"""
def setUp(self):
def setUp(self) -> None:
self.client = client = base.BaseClient()
client.sock = anyio_mock.AsyncMock()
client.sock = AsyncMock()
client.sock.send.return_value = None
@anyio_mock.with_anyio()
async def test_simple(self):
async def test_simple(self) -> None:
"""
Check that an awaited message is returned when is arrives
"""
......@@ -287,8 +263,7 @@ class EventTests(unittest.TestCase):
assert evt == "CTRL-EVENT-EXAMPLE"
assert args is None
@anyio_mock.with_anyio()
async def test_multiple(self):
async def test_multiple(self) -> None:
"""
Check that an awaited messages is returned when it arrives between others
"""
......@@ -307,8 +282,7 @@ class EventTests(unittest.TestCase):
assert evt == "CTRL-EVENT-EXAMPLE"
assert args is None
@anyio_mock.with_anyio()
async def test_wait_multiple(self):
async def test_wait_multiple(self) -> None:
"""
Check that the first of several awaited events is returned
"""
......@@ -330,8 +304,7 @@ class EventTests(unittest.TestCase):
assert evt == "CTRL-EVENT-EXAMPLE3"
assert args is None
@anyio_mock.with_anyio()
async def test_interleaved(self):
async def test_interleaved(self) -> None:
"""
Check that messages are processed as well as replies
"""
......@@ -357,8 +330,7 @@ class EventTests(unittest.TestCase):
assert await client.send_command("SOME_COMMAND", expect="FOO") is None
@anyio_mock.with_anyio()
async def test_unconnected(self):
async def test_unconnected(self) -> None:
"""
Check that calling event() on an unconnected client raises RuntimeError
"""
......
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,25 +18,24 @@ Test cases for wpa_supplicant.client.GlobalClient
import pathlib
import unittest
from unittest.mock import AsyncMock
from unittest.mock import patch
from tests import _anyio as anyio_mock
from wpa_supplicant.client import GlobalClient
from wpa_supplicant.client import InterfaceClient
class InterfaceMethodsTests(unittest.TestCase):
class InterfaceMethodsTests(unittest.IsolatedAsyncioTestCase):
"""
Tests for the *_interface(s?) methods
"""
def setUp(self):
def setUp(self) -> None:
self.client = client = GlobalClient()
client.sock = anyio_mock.AsyncMock()
client.sock = AsyncMock()
client.sock.send.return_value = None
@anyio_mock.with_anyio()
async def test_connect(self):
async def test_connect(self) -> None:
"""
Check that connect sets ctrl_dir
"""
......@@ -45,7 +44,7 @@ class InterfaceMethodsTests(unittest.TestCase):
with patch(
"wpa_supplicant.client.base.BaseClient.connect",
new_callable=anyio_mock.AsyncMock,
new_callable=AsyncMock,
):
await client1.connect("/tmp/foo/bar")
await client2.connect(pathlib.Path("/tmp/foo/bar"))
......@@ -56,8 +55,7 @@ class InterfaceMethodsTests(unittest.TestCase):
assert client1.ctrl_dir == pathlib.Path("/tmp/foo")
assert client2.ctrl_dir == pathlib.Path("/tmp/foo")
@anyio_mock.with_anyio()
async def test_list_interfaces(self):
async def test_list_interfaces(self) -> None:
"""
Check list_interfaces() processes lines of names in a list
"""
......@@ -76,8 +74,7 @@ class InterfaceMethodsTests(unittest.TestCase):
client.sock.send.assert_called_once_with(b"INTERFACES")
@anyio_mock.with_anyio()
async def test_add_interface(self):
async def test_add_interface(self) -> None:
"""
Check add_interface() sends the correct arguments
"""
......@@ -93,10 +90,9 @@ class InterfaceMethodsTests(unittest.TestCase):
@patch(
"wpa_supplicant.client.interfaces.InterfaceClient.connect",
new_callable=anyio_mock.AsyncMock,
new_callable=AsyncMock,
)
@anyio_mock.with_anyio()
async def test_connect_interface(self, connect_mock):
async def test_connect_interface(self, connect_mock: AsyncMock) -> None:
"""
Check connect_interface() returns a connected InterfaceClient
"""
......@@ -116,10 +112,9 @@ class InterfaceMethodsTests(unittest.TestCase):
@patch(
"wpa_supplicant.client.interfaces.InterfaceClient.connect",
new_callable=anyio_mock.AsyncMock,
new_callable=AsyncMock,
)
@anyio_mock.with_anyio()
async def test_connect_interface_with_add(self, connect_mock):
async def test_connect_interface_with_add(self, connect_mock: AsyncMock) -> None:
"""
Check connect_interface() adds the interface when not already managed
"""
......@@ -140,8 +135,7 @@ class InterfaceMethodsTests(unittest.TestCase):
self.assertTupleEqual(args[0][0], (b"INTERFACES",))
assert args[1][0][0].startswith(b"INTERFACE_ADD enp1s0\t")
@anyio_mock.with_anyio()
async def test_unconnected(self):
async def test_unconnected(self) -> None:
"""
Check that calling add_interface() on an unconnected client raises RuntimeError
......
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -17,26 +17,27 @@ Test cases for wpa_supplicant.client.interfaces.InterfaceClient
"""
import unittest
from collections.abc import Iterator
from contextlib import contextmanager
from unittest.mock import AsyncMock
from unittest.mock import call
from tests import _anyio as anyio_mock
from wpa_supplicant import config
from wpa_supplicant.client import interfaces
class MethodsTests(unittest.TestCase):
class MethodsTests(unittest.IsolatedAsyncioTestCase):
"""
Tests for InterfaceClient methods
"""
def setUp(self):
def setUp(self) -> None:
self.client = client = interfaces.InterfaceClient()
client.sock = anyio_mock.AsyncMock()
client.sock = AsyncMock()
client.sock.send.return_value = None
@contextmanager
def subTest(self, *args, reset=[], **kwargs):
def subTest(self, *args: object, reset: list[AsyncMock] = [], **kwargs: object) -> Iterator[None]:
with super().subTest(*args, **kwargs):
try:
yield
......@@ -44,8 +45,7 @@ class MethodsTests(unittest.TestCase):
for mock in reset:
mock.reset_mock()
@anyio_mock.with_anyio()
async def test_scan(self):
async def test_scan(self) -> None:
"""
Check that a scan command waits for a notification then terminates correctly
"""
......@@ -64,8 +64,7 @@ class MethodsTests(unittest.TestCase):
self.assertIsInstance(bss, dict)
self.assertIn("good", bss)
@anyio_mock.with_anyio()
async def test_set_network(self):
async def test_set_network(self) -> None:
"""
Check that set_network sends values to the daemon and raises TypeError for bad types
"""
......@@ -105,8 +104,7 @@ class MethodsTests(unittest.TestCase):
self.assertRaises(TypeError):
await client.set_network("0", "key_mgmt", 1)
@anyio_mock.with_anyio()
async def test_add_network(self):
async def test_add_network(self) -> None:
"""
Check that add_network adds a new network and configures it
"""
......
"""
Async control of WPA-Supplicant from a Python process
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -15,5 +15,3 @@ Async control of WPA-Supplicant from a Python process
# See the License for the specific language governing permissions and
# limitations under the License.
"""
__version__ = "0.3.0"
# Copyright 2021 Dom Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -19,47 +19,15 @@ Work-arounds for lack of AF_UNIX datagram socket support in Anyio
from __future__ import annotations
import errno
import os
import socket
import tempfile
from contextlib import suppress
from os import PathLike
from typing import Any
from typing import Callable
from typing import Coroutine
from typing import Dict
from typing import Protocol
from typing import Union
from typing import cast
import sniffio
from anyio import create_connected_unix_datagram_socket
from anyio.abc import ConnectedUNIXDatagramSocket as DatagramSocket
ConnectorFn = Callable[[str, str], Coroutine[Any, Any, 'DatagramSocket']]
connectors: Dict[str, ConnectorFn] = {}
class DatagramSocket(Protocol):
@property
def _raw_socket(self) -> socket.socket: ...
async def aclose(self) -> None: ...
async def receive(self) -> bytes: ...
async def send(self, item: bytes) -> None: ...
class ConnectedUNIXMixin:
async def aclose(self: DatagramSocket) -> None:
path = self._raw_socket.getsockname()
await super().aclose() # type: ignore # Mypy doesn't handle super() well in mixins
os.unlink(path)
async def connect_unix_datagram(path: Union[str, PathLike[str]]) -> DatagramSocket:
async def connect_unix_datagram(path: str | PathLike[str]) -> DatagramSocket:
"""
Return an AnyIO socket connected to a Unix datagram socket
......@@ -68,82 +36,7 @@ async def connect_unix_datagram(path: Union[str, PathLike[str]]) -> DatagramSock
for _ in range(10):
fname = tempfile.mktemp(suffix=".sock", prefix="wpa_ctrl.")
with suppress(FileExistsError):
async_lib = sniffio.current_async_library()
connector = connectors[async_lib]
return await connector(fname, os.fspath(path))
return await create_connected_unix_datagram_socket(path, local_path=fname)
raise FileExistsError(
errno.EEXIST, "No usable temporary filename found",
)
try:
import trio
except ImportError: ...
else:
from anyio._backends import _trio
class TrioConnectedUNIXSocket(ConnectedUNIXMixin, _trio.ConnectedUDPSocket):
...
async def trio_connect_unix_datagram(
local_path: str,
remote_path: str,
) -> TrioConnectedUNIXSocket:
sock = trio.socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
await sock.bind(local_path)
try:
await sock.connect(remote_path)
except BaseException: # pragma: no cover
sock.close()
raise
else:
return TrioConnectedUNIXSocket(sock)
connectors['trio'] = trio_connect_unix_datagram
# asyncio is in the stdlib, but lets make the layout match trio 😉
try:
import asyncio
except ImportError: ...
else:
from anyio._backends import _asyncio
class AsyncioConnectedUNIXSocket(ConnectedUNIXMixin, _asyncio.ConnectedUDPSocket):
...
async def asyncio_connect_unix_datagram(
local_path: str,
remote_path: str,
) -> AsyncioConnectedUNIXSocket:
await asyncio.sleep(0.0)
loop = asyncio.get_running_loop()
sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
sock.setblocking(False)
sock.bind(local_path)
while True:
try:
sock.connect(remote_path)
except BlockingIOError:
future: asyncio.Future[None] = asyncio.Future()
loop.add_writer(sock, future.set_result, None)
future.add_done_callback(lambda _: loop.remove_writer(sock))
await future
except BaseException:
sock.close()
raise
else:
break
transport_, protocol_ = await asyncio.get_running_loop().create_datagram_endpoint(
_asyncio.DatagramProtocol,
sock=sock,
)
transport = cast(asyncio.DatagramTransport, transport_)
protocol = cast(_asyncio.DatagramProtocol, protocol_)
if protocol.exception:
transport.close()
raise protocol.exception
return AsyncioConnectedUNIXSocket(transport, protocol)
connectors['asyncio'] = asyncio_connect_unix_datagram
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -16,14 +16,7 @@
WPA-Supplicant client classes
"""
from ._global import GlobalClient
from .base import BaseClient
from ._global import GlobalClient as GlobalClient
from .base import BaseClient as BaseClient
from .consts import *
from .consts import __all__ as _consts_names
from .interfaces import InterfaceClient
__all__ = _consts_names + (
'BaseClient',
'GlobalClient',
'InterfaceClient',
)
from .interfaces import InterfaceClient as InterfaceClient
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -20,7 +20,6 @@ from __future__ import annotations
import pathlib
from os import PathLike
from typing import Set
from . import consts
from .base import BaseClient
......@@ -40,7 +39,7 @@ class GlobalClient(BaseClient):
await super().connect(path)
self.ctrl_dir = path.parent
async def list_interfaces(self) -> Set[str]:
async def list_interfaces(self) -> set[str]:
"""
Return a set of the interfaces currently managed by the daemon
"""
......@@ -52,14 +51,10 @@ class GlobalClient(BaseClient):
"""
Add a network interface to the daemon's control interfaces
"""
if self.ctrl_dir:
ctrl_iface = f"DIR={self.ctrl_dir} GROUP={self.ctrl_dir.group()}"
else:
# RuntimeError should be raised by send_command() as connect() does not appear
# to have been called; set ctrl_iface to any string
ctrl_iface = ""
await self.send_command(
consts.COMMAND_INTERFACE_ADD, ifname, "", driver, ctrl_iface, driver_param,
consts.COMMAND_INTERFACE_ADD, ifname, "", driver,
f"DIR={self.ctrl_dir} GROUP={self.ctrl_dir.group()}" if self.ctrl_dir else "",
driver_param,
)
assert self.ctrl_dir is not None, \
"RuntimeError should be raised for sends on unconnected clients; " \
......
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,7 +13,7 @@
# limitations under the License.
"""
This module provides a base WPA-Supplicant client implementation
Base implementation for WPA-Supplicant client classes
"""
from __future__ import annotations
......@@ -21,17 +21,15 @@ from __future__ import annotations
import enum
import logging
import os
import sys
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from re import compile as regex
from types import TracebackType as Traceback
from typing import Any
from typing import AsyncContextManager
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
from typing import overload
import anyio
......@@ -53,7 +51,7 @@ class EventPriority(enum.IntEnum):
Event Message priorities
"""
def get_logger_level(self, *, _mapping: Dict[EventPriority, int] = {}) -> int:
def get_logger_level(self, *, _mapping: dict[EventPriority, int] = {}) -> int:
"""
Return a logging level matching the `wpa_supplicant` priority level
"""
......@@ -93,15 +91,15 @@ class BaseClient:
event_regex = regex(r"<([0-9]+)>(?:((?:CTRL|WPS|AP|P2P)-[A-Z0-9-]+)(?:\s|$))?(.+)?")
def __init__(self, *, logger: Optional[logging.Logger] = None):
def __init__(self, *, logger: logging.Logger | None = None) -> None:
self.logger = logger or logging.getLogger(__package__)
self.ctrl_dir = None
self.sock: Optional[DatagramSocket] = None
self.sock: DatagramSocket | None = None
self._lock = anyio.Lock()
self._condition = anyio.Condition()
self._handler_active = False
self._reply: Union[_ReplyState, str] = _ReplyState.NOTHING
self._event: Optional[EventInfo]
self._reply: _ReplyState | str = _ReplyState.NOTHING
self._event: EventInfo | None
self._eventcount = 0
async def __aenter__(self) -> BaseClient:
......@@ -109,9 +107,9 @@ class BaseClient:
async def __aexit__(
self,
_et: Optional[Type[BaseException]],
_e: Optional[BaseException],
_tb: Optional[Traceback],
_et: type[BaseException] | None,
_e: BaseException | None,
_tb: Traceback | None,
) -> None:
await self.disconnect()
......@@ -159,8 +157,8 @@ class BaseClient:
*args: str,
separator: str = consts.SEPARATOR_TAB,
expect: str = consts.RESPONSE_OK,
convert: Optional[Callable[[str], T]] = None,
) -> Optional[T]:
convert: Callable[[str], T] | None = None,
) -> T | None:
"""
Send a message and await a response
......@@ -215,11 +213,28 @@ class BaseClient:
)
return None
def attach(self) -> AsyncContextManager[None]:
@asynccontextmanager
async def attach(self) -> AsyncIterator[None]:
"""
Return a context manager that handles attaching to the daemon's message queue
"""
return self._AttachContext(self)
assert self._eventcount >= 0
self._eventcount += 1
if self._eventcount == 1:
await self.send_command(consts.COMMAND_ATTACH)
try:
yield
except:
if __debug__:
exc_type, *_ = sys.exc_info()
assert exc_type is not None
self.logger.debug("Detaching due to %s", exc_type.__name__)
raise
finally:
assert self._eventcount > 0
self._eventcount -= 1
if self._eventcount == 0:
await self.send_command(consts.COMMAND_DETACH)
async def event(self, *events: str) -> EventInfo:
"""
......@@ -262,45 +277,25 @@ class BaseClient:
raise anyio.ClosedResourceError
self.logger.debug("Received: %s", repr(msg))
match = self.event_regex.match(msg)
# If matched, it is an event
if match:
prio_, name, msg = match.groups()
prio = EventPriority(int(prio_))
# If it's not an event, check whether a reply to a sent message is expected
elif self._reply is not _ReplyState.AWAITING:
self.logger.warning("Unexpected response message: %s", msg)
return
else:
self._reply = msg
return
# Unnamed events are just for logging
if not name:
self.logger.log(prio.get_logger_level(), msg)
return
self._event = (prio, name, msg or None)
class _AttachContext:
def __init__(self, client: BaseClient):
self.client = client
async def __aenter__(self) -> None:
client = self.client
assert client._eventcount >= 0
if client._eventcount == 0:
await client.send_command(consts.COMMAND_ATTACH)
client._eventcount += 1
async def __aexit__(self, *exc_info: Any) -> None:
client = self.client
assert client._eventcount > 0
client._eventcount -= 1
if client._eventcount == 0:
if __debug__: # On it's own for compiler optimisation
if exc_info[0]:
client.logger.debug(f"Detaching due to {exc_info[0].__name__}")
await client.send_command(consts.COMMAND_DETACH)
match self._parse_message(msg):
case str(msg):
if self._reply is _ReplyState.AWAITING:
self._reply = msg
else:
self.logger.warning("Unexpected response message: %s", msg)
case [prio, name, message] if name is None:
# Unnamed events are just for logging
assert message is not None, "empty log message received"
self.logger.log(prio.get_logger_level(), message)
case [prio, str(name), message]:
self._event = (prio, name, message)
case _: # pragma: no-cover
raise AssertionError("unexpected return from BaseClient._parse_message()")
@classmethod
def _parse_message(cls, message: str) -> tuple[EventPriority, str|None, str|None] | str:
if not (rematch := cls.event_regex.match(message)):
return message
prio_, name, msg = rematch.groups()
return EventPriority(int(prio_)), name, msg
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -20,7 +20,6 @@ from __future__ import annotations
from itertools import count
from os import PathLike
from typing import Any
from typing import AsyncGenerator
from typing import Dict
......@@ -39,12 +38,15 @@ class InterfaceClient(BaseClient):
name = None
async def connect(self, path: PathLike[str]) -> None:
"""
Connect to an interface UNIX port
"""
await super().connect(path)
self.name = await self.send_command(consts.COMMAND_IFNAME, convert=str)
async def scan(self) -> AsyncGenerator[StringMap, None]:
"""
Iteratively produces the details of all detectable IEEE 802.11 BSS
Yield the details of all detectable IEEE 802.11 BSS
(WiFi Access Points to you and me)
"""
......@@ -59,7 +61,7 @@ class InterfaceClient(BaseClient):
return
yield bss
async def add_network(self, configuration: Dict[str, Any]) -> int:
async def add_network(self, configuration: dict[str, object]) -> int:
"""Add a new network configuration"""
netid = await self.send_command(consts.COMMAND_ADD_NETWORK, convert=str)
for var, val in configuration.items():
......@@ -67,7 +69,7 @@ class InterfaceClient(BaseClient):
await self.send_command(consts.COMMAND_ENABLE_NETWORK, netid)
return int(netid)
async def set_network(self, netid: str, variable: str, value: Any) -> None:
async def set_network(self, netid: str, variable: str, value: object) -> None:
"""Set a network configuration option"""
if not isinstance(value, config.get_type(variable)):
raise TypeError(f"Wrong type for {variable}: {value!r}")
......@@ -84,4 +86,4 @@ def _kv2dict(keyvalues: str) -> StringMap:
"""
Convert a list of line-terminated "key=value" substrings into a dictionary
"""
return dict(kv.split("=", 1) for kv in keyvalues.splitlines()) # type: ignore
return dict(kv.split("=", 1) for kv in keyvalues.splitlines())
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk>
# Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,7 +18,6 @@ Helpers for network configuration
from enum import Enum
from enum import auto
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
......@@ -67,7 +66,7 @@ def get_type(variable: str) -> type:
class _UnknownTypeMeta(type):
def __instancecheck__(cls, instance: Any) -> bool:
def __instancecheck__(cls, instance: object) -> bool:
return isinstance(instance, (str, int))
......@@ -82,7 +81,7 @@ class ConfigEnum(Enum):
return str(self.value)
@staticmethod
def _generate_next_value_(name: str, *_: Any) -> str:
def _generate_next_value_(name: str, start: object, count: object, last_values: object) -> str:
return name.replace("_", "-")
......