Skip to content
Commits on Source (21)
...@@ -12,7 +12,5 @@ build/ ...@@ -12,7 +12,5 @@ build/
dist/ dist/
# unit-testing, coverage, etc. # unit-testing, coverage, etc.
/*.xml /*_cache/
/*.json /results/
.coverage*
.noseids
stages: include:
- test - project: dom/project-templates
- acceptance file: /pipeline-templates/pre-commit.yml
- build - project: dom/project-templates
- publish ref: tmp
file: /pipeline-templates/python-package.yml
image: python:3.8
variables:
PIP_CACHE_DIR: $CI_PROJECT_DIR/.cache/pip
cache:
key: all
paths:
- .cache/
- .eggs/
.python:
image: python:3.12
variables:
PIP_CACHE_DIR: $CI_PROJECT_DIR/cache/pkg
PIP_NO_COMPILE: "true"
PIP_NO_CLEAN: "true"
cache:
key: $CI_JOB_IMAGE
paths: [cache]
# Unit Tests
.unittest: .unittest:
stage: test stage: test
image: python:$PY_VERSION image: python:$PY_VERSION
before_script: before_script:
- pip install -e .[test] nose2 - pip install -e .[test] coverage pytest
- mkdir results
script: script:
- nose2 -v -c setup.cfg - coverage run -m pytest -v --junit-xml=results/junit.xml tests/unit
after_script: after_script:
- mv .coverage .coverage.$PY_VERSION - mv results results.$PY_VERSION
- mv .unittest.xml .unittest.$PY_VERSION.xml
coverage: '/^TOTAL.* ([0-9.]+\%)$/'
artifacts: artifacts:
when: always when: always
paths: paths:
- .noseids - results.$PY_VERSION
- .coverage.*
- .unittest.$PY_VERSION.xml
reports: reports:
junit: .unittest.$PY_VERSION.xml junit: results.$PY_VERSION/junit.xml
unittest:3.8: Unit Tests (Py 3.10):
extends: .unittest extends: [.python, .unittest]
variables: variables:
PY_VERSION: '3.8' PY_VERSION: '3.10'
unittest:3.9: Unit Tests (Py 3.11):
extends: .unittest extends: [.python, .unittest]
variables: variables:
PY_VERSION: '3.9' PY_VERSION: '3.11'
unittest:3.10: Unit Tests (Py 3.12):
extends: .unittest extends: [.python, .unittest]
variables: variables:
PY_VERSION: '3.10-rc' PY_VERSION: '3.12'
publish:unittests: Publish Unit Tests:
stage: publish stage: deploy
extends: [.python]
when: always when: always
dependencies: &unittests needs: &unittests
- unittest:3.8 - Unit Tests (Py 3.10)
- unittest:3.9 - Unit Tests (Py 3.11)
- unittest:3.10 - Unit Tests (Py 3.12)
needs: *unittests
script: script:
- pip install --upgrade junit2html - pip install --upgrade junit2html
- mkdir -p unittest - mkdir -p unittest
- python util/junit_merge.py .unittest.*.xml > .unittest.xml - python util/junit_merge.py results.*/junit.xml > junit.xml
- junit2html .unittest.xml unittest/index.html - junit2html junit.xml unittest/index.html
artifacts: artifacts:
when: always when: always
paths: paths:
- .unittest.xml - junit.xml
- unittest - unittest
reports:
junit: junit.xml
# Aggregate Coverage
coverage: Aggregate Coverage:
stage: acceptance stage: test
extends: [.python]
when: always when: always
dependencies: *unittests
needs: *unittests needs: *unittests
script: script:
- pip install --upgrade coverage - pip install --upgrade coverage
- coverage combine - coverage combine results.*/coverage.db
- coverage report - coverage report
coverage: '/^TOTAL.* ([0-9.]+\%)$/' coverage: '/^TOTAL.* ([0-9.]+\%)$/'
artifacts: artifacts:
when: always when: always
paths: paths:
- .coverage - results
publish:coverage: Publish Coverage:
stage: publish stage: deploy
extends: [.python]
when: always when: always
dependencies: [coverage] needs: [Aggregate Coverage]
needs: [coverage]
script: script:
- pip install --upgrade coverage - pip install --upgrade coverage
- coverage html --fail-under=0 -d coverage - coverage html --fail-under=0 -d results/coverage.html
- coverage xml --fail-under=0 -o coverage/coverage.xml - coverage xml --fail-under=0 -o results/coverage.xml
artifacts: artifacts:
when: always when: always
paths: paths:
- coverage - results
# Quality Assurance
Code Analysis:
stage: test
variables:
FROM_REF: $CI_DEFAULT_BRANCH
rules:
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
variables:
FROM_REF: $CI_COMMIT_BEFORE_SHA
- when: always
before_script:
- pip install pre-commit
script:
- git fetch $CI_REPOSITORY_URL $FROM_REF:FROM_REF -f
- pre-commit run
--hook-stage=commit
--from-ref=FROM_REF
--to-ref=${CI_COMMIT_SHA}
Commit Graph Analysis:
stage: test
variables:
FROM_REF: $CI_DEFAULT_BRANCH
rules:
- if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
variables:
FROM_REF: $CI_COMMIT_BEFORE_SHA
- if: $CI_PIPELINE_TRIGGERED == "merge_request_event"
before_script:
- pip install pre-commit
script:
- pre-commit run
--hook-stage=push
--from-ref=FROM_REF
--to-ref=${CI_COMMIT_SHA}
# Package publishing
Check Tag:
stage: test
rules:
- if: $CI_COMMIT_TAG =~ /^v[0-9]/
script:
- test `./setup.py --version` == ${CI_COMMIT_TAG#v}
Build Packages:
stage: build
script:
- ./setup.py bdist_wheel sdist
artifacts:
paths:
- dist
Upload Packages:
stage: publish
rules:
- if: $CI_COMMIT_TAG =~ /^v[0-9]/
dependencies:
- Build Packages
script:
- pip install twine
- TWINE_PASSWORD=$CI_JOB_TOKEN
TWINE_USERNAME=gitlab-ci-token
twine upload
--verbose
--non-interactive
--repository-url $CI_API_V4_URL/projects/$CI_PROJECT_ID/packages/pypi
dist/*
...@@ -4,9 +4,10 @@ repos: ...@@ -4,9 +4,10 @@ repos:
- repo: meta - repo: meta
hooks: hooks:
- id: check-hooks-apply - id: check-hooks-apply
- id: check-useless-excludes
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.4.0 rev: v4.6.0
hooks: hooks:
- id: check-added-large-files - id: check-added-large-files
- id: check-case-conflict - id: check-case-conflict
...@@ -17,77 +18,84 @@ repos: ...@@ -17,77 +18,84 @@ repos:
- id: debug-statements - id: debug-statements
- id: destroyed-symlinks - id: destroyed-symlinks
- id: end-of-file-fixer - id: end-of-file-fixer
stages: [commit] stages: [commit, manual]
- id: fix-byte-order-marker - id: fix-byte-order-marker
- id: fix-encoding-pragma - id: fix-encoding-pragma
args: [--remove] args: [--remove]
- id: mixed-line-ending - id: mixed-line-ending
args: [--fix=lf] args: [--fix=lf]
stages: [commit, manual]
- id: trailing-whitespace - id: trailing-whitespace
exclude_types: [markdown, plain-text] exclude_types: [markdown, plain-text]
stages: [commit] stages: [commit, manual]
- repo: https://github.com/jorisroovers/gitlint - repo: https://github.com/jorisroovers/gitlint
rev: v0.15.0 rev: v0.19.1
hooks: hooks:
- id: gitlint - id: gitlint
- repo: https://code.kodo.org.uk/dom/pre-commit-hooks - repo: https://code.kodo.org.uk/dom/pre-commit-hooks
rev: v0.6 rev: v0.6.1
hooks: hooks:
- id: check-executable-modes - id: check-executable-modes
- id: check-for-squash - id: check-for-squash
- id: copyright-notice - id: copyright-notice
args: [--min-size=1] args: [--min-size=1]
exclude: setup\.py stages: [commit, manual]
- id: protect-first-parent - id: protect-first-parent
- repo: https://github.com/pre-commit/pygrep-hooks - repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.8.0 rev: v1.10.0
hooks: hooks:
- id: python-no-eval - id: python-no-eval
- id: python-no-log-warn - id: python-no-log-warn
- id: python-use-type-annotations - id: python-use-type-annotations
- repo: https://github.com/hakancelik96/unimport - repo: https://github.com/hakancelikdev/unimport
rev: 0.8.4 rev: 1.2.1
hooks: hooks:
- id: unimport - id: unimport
args: [--remove, --exclude=types.py|__init__.py] args: [--remove, --exclude=types.py|__init__.py]
stages: [commit, manual] stages: [commit, manual]
- repo: https://github.com/timothycrosley/isort - repo: https://github.com/pycqa/isort
rev: 5.7.0 rev: 5.13.2
hooks: hooks:
- id: isort - id: isort
types: [python] types: [python]
stages: [commit, manual] stages: [commit, manual]
- repo: https://github.com/asottile/add-trailing-comma - repo: https://github.com/asottile/add-trailing-comma
rev: v2.1.0 rev: v3.1.0
hooks: hooks:
- id: add-trailing-comma - id: add-trailing-comma
args: [--py36-plus] args: [--py36-plus]
stages: [commit, manual] stages: [commit, manual]
- repo: https://gitlab.com/pycqa/flake8 - repo: https://github.com/astral-sh/ruff-pre-commit
rev: 3.8.3 rev: v0.5.5
hooks: hooks:
- id: flake8 - id: ruff
args: ["--config=setup.cfg"] exclude: "^util/"
additional_dependencies: args: [--fix, --unsafe-fixes]
- flake8-bugbear
- flake8-docstrings
- flake8-print
- flake8-requirements
- flake8-return
- flake8-sfs
- flake8-tabs
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.910 rev: v1.11.1
hooks: hooks:
- id: mypy - id: mypy
args: [--config-file=setup.cfg] args:
additional_dependencies: [anyio, trio-typing] - --python-version=3.10
exclude: setup\.py|test_.* - --follow-imports=silent
- wpa_supplicant
pass_filenames: false
additional_dependencies: &type-deps
- anyio ~=4.0
- trio-typing
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.373
hooks:
- id: pyright
args: ["--pythonversion=3.11"]
pass_filenames: false
additional_dependencies: *type-deps
line-length = 92
indent-width = 1 # Used for line length violations
[lint]
select = [
# pyflakes
# --------
# ENABLE "Undefined name %s in __all__"
"F822",
# ENABLE "Local variable %s ... referenced before assignment"
"F823",
# ENABLE "Local variable %s is assigned to but never used"
"F841",
# ENABLE "raise NotImplemented should be raise NotImplementedError"
# mypy has particular trouble with this one:
# https://github.com/python/mypy/issues/5710
"F901",
# pycodestyle
# -----------
# Warnings not considered, many are not relevant to Python ~=3.9 and will
# cause syntax errors anyway, others concern whitespace which is fixed by
# a pre-commit hook.
"E",
# mccabe
# ------
"C90",
# pydocstyle
# ----------
# Missing docstrings
"D1",
# Whitespace Issues
"D2",
# ENABLE "Use “””triple double quotes”””"
"D300",
# First line should be descriptive, imperative and capitalised
"D401", "D402", "D403", "D404",
# ENABLE "Function/Method decorated with @overload shouldn’t contain a docstring"
"D418",
# flake8-bugbear
# --------------
# The bulk of bugbear's checks are useful
"B0",
# Various others
# --------------
"UP", "BLE", "FBT", "A", "COM", "C4", "DTZ", "ISC", "LOG", "G", "PIE", "T",
"Q", "RSE", "RET", "SLF", "SLOT", "SIM", "TD", "ANN", #"FA",
# Nice to have, needs fixing in several places though...
# "EM", "TCH", "PTH", "PGH",
]
ignore = [
# pycodestyle
# -----------
# DISABLE "Indentation contains mixed spaces and tabs"
# Will cause a syntax error if critical, otherwise in docstrings it is
# sometimes nice to use different indentation for "outer" (code) indentation
# and "inner" (documentation) indentation.
"E101",
# DISABLE "Continuation line missing indentation or outdented"
# "E122",
# DISABLE "Missing whitespace around bitwise or shift operator"
"E227",
# DISABLE "missing whitespace around arithmetic operator"
"E226",
# DISABLE "Line too long"
# Prefer B950 implementation
"E501",
# DISABLE "Multiple statements on one line (colon)"
"E701",
# DISABLE "Multiple statements on one line (def)"
# Doesn't work well with @overload definitions
# "E704",
# pydocstyle
# ----------
# DISABLE "Missing docstring in magic method"
# Magic/dunder methods are well-known
"D105",
# DISABLE "Missing docstring in __init__"
# Document basic construction in the class docstring
"D107",
# DISABLE "One-line docstring should fit on one line with quotes"
# Prefer top-and-bottom style always
"D200",
# DISABLE "1 blank line required before class docstring"
"D203",
# DISABLE "Docstring should be indented with spaces, not tabs"
# Tabs, absolutely always
"D206",
# DISABLE "Multi-line docstring summary should start at the first line"
"D212",
# flake8-bugbear
# --------------
# DISABLE "Do not use mutable data structures for argument defaults [...]"
# Would be nice if could take into account use as a non-mutable type
"B006",
# DISABLE "release is an empty method in an abstract base class, [...]"
# Until abstract methods are optional, empty optional "abstract" methods
# stay
"B027",
# Use named-tuples (preferably class based) for data-only classes
# "B903",
# Replacement for E501
# "B950",
# flake8-return
# -------------
# DISABLE "missing explicit return at the end of function able to return
# non-None value"
# Mypy will report this, plugin also cannot do exhaustiveness check of match
# block, leading to false-positives.
"RET503",
# DISABLE "Missing type annotation for `%` in {method|classmethod}"
# Don't type 'self' or 'cls'
"ANN101", "ANN102",
# DISABLE "Boolean positional value in function call"
# Too many stdlib functions take a single positional-only boolean. ruff
# can't interpret function signatures to ignore these and doesn't understand
# types to allow-list methods.
"FBT003",
# DISABLE "Implicitly concatenated string literals over multiple lines"
# It sometimes looks better to do this than introduce unecessary
# parentheses.
"ISC002",
# Unfortunately a lot of single quotes strings used in this project already
"Q000",
]
[lint.per-file-ignores]
"**/__init__.py" = ["D104"]
"**/__main__.py" = ["D100", "E702"]
"**/_*.py" = ["D1"]
"examples/**.py" = ["T"]
"tests/*" = ["D1"]
"doc/*" = ["D"]
"README.md" = ["D"]
[build-system] [build-system]
requires = ["setuptools>=40.8.0", "wheel"] requires = ["flit_core ~=3.8"]
build-backend = "setuptools.build_meta:__legacy__" build-backend = "flit_core.buildapi"
[project]
name = "wpa-supplicant-client"
version = "0.3.0"
description = "A client package for connecting to, configuring and controlling wpa_supplicant daemons"
license = {file = "LICENCE.txt"}
authors = [
{name = "Dom Sekotill", email = "dom.sekotill@kodo.org.uk"},
]
classifiers = [
"Intended Audience :: Developers",
"Operating System :: POSIX",
]
requires-python = "~=3.10"
dependencies = [
"anyio ~=4.1",
]
[project.optional-dependencies]
test = [
"trio",
]
[project.urls]
Repository = "https://code.kodo.org.uk/dom/wpa-supplicant-client"
Issues = "https://code.kodo.org.uk/dom/wpa-supplicant-client/-/issues"
[tool.flit.module]
name = "wpa_supplicant"
[tool.isort] [tool.isort]
force_single_line = true force_single_line = true
line_length = 92
[tool.unimport] [tool.unimport]
ignore-init = true ignore-init = true
[tool.mypy]
allow_redefinition = true
explicit_package_bases = true
implicit_reexport = true
strict = true
warn_unreachable = true
warn_unused_configs = true
[tool.pyright]
include = ["wpa_supplicant"]
typeCheckingMode = "strict"
reportMissingModuleSource = "none"
reportRedeclaration = "none"
reportUnknownMemberType = "warning"
reportUnusedImport = "warning"
[tool.coverage.run]
data_file = "results/coverage.db"
branch = true
source = ["wpa_supplicant"]
[tool.coverage.report]
precision = 2
skip_empty = true
exclude_lines = [
"pragma: no-cover",
"if .*\\b__name__\\b",
"if .*\\bTYPE_CHECKING\\b",
"class .*(.*\\bProtocol\\b.*):",
"def __repr__",
"@overload",
"@(abc\\.)abstractmethod",
]
partial_branches = [
"pragma: no-branch",
"if .*\\b__debug__\\b",
]
[tool.coverage.json]
output = "results/coverage.json"
show_contexts = true
[tool.coverage.xml]
output = "results/coverage.xml"
[tool.coverage.html]
directory = "results/coverage"
show_contexts = true
[metadata]
name = wpa-supplicant-client
version = attr: wpa_supplicant.__version__
author = Dom Sekotill
author_email = dom.sekotill@kodo.org.uk
description = A client package for connecting to, configuring and controlling wpa_supplicant daemons
long_description = file: README.md
long_description_content_type = text/markdown
url = 'https://code.kodo.org.uk/dom/wpa-supplicant-client.git'
license = Apache-2.0
license_files =
LICENCE.txt
classifiers =
Development Status :: 2 - Pre-Alpha
Intended Audience :: Developers
License :: OSI Approved
License :: OSI Approved :: Apache Software License
Natural Language :: English
Operating System :: POSIX
Programming Language :: Python
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Typing::Typed
[options]
python_requires = >= 3.8
packages = find:
setup_requires =
setuptools >= 40.6
install_requires =
anyio ~=3.0
[options.packages.find]
include =
wpa_supplicant
wpa_supplicant.*
[options.package_data]
wpa_supplicant = py.typed
[options.extras_require]
test =
nose2[coverage_plugin]
trio
[unittest]
start-dir = tests/unit
verbose = True
plugins =
nose2.plugins.junitxml
nose2.plugins.testid
[coverage]
always-on = True
coverage = wpa_supplicant
coverage-report = term
[coverage:run]
branch = True
[coverage:report]
fail_under = 80
precision = 2
show_missing = True
omit = **/__main__.py
exclude_lines =
pragma: no cover
if __name__ == .__main__.:
def __repr__
__version__ =
@(.*\.)?abstract((static|class)?method|property)
except ImportError:\s*(...|pass)
[coverage:xml]
output = .coverage.xml
[coverage:html]
directory = .coverage.html.d
[junit-xml]
always-on = True
path = .unittest.xml
[testid]
always-on = True
[log-capture]
always-on = True
[isort]
force_single_line = true
[mypy]
strict = true
warn_unused_configs = True
warn_unreachable = true
implicit_reexport = true
[flake8]
max-line-length = 92
max-doc-length = 92
use-flake8-tabs = true
blank-lines-indent = never
indent-tabs-def = 1
format = pylint
select = C,D,E,ET,F,SFS,T,W,WT
extend-exclude =
setup.py
per-file-ignores =
setup.py: D100, E702
tests/*.py: D100, C801
**/__init__.py: D104, F401, F403
**/__main__.py: D100, E702
**/_*.py: D
ignore =
;[ Missing docstring in public method ]
; Handled by pylint, which does it better
D102
;[ Missing docstring in magic method ]
; Magic/dunder methods are well-known
D105
;[ Misisng docstring in __init__ ]
; Document basic construction in the class docstring
D107
;[ One-line docstring should fit on one line with quotes ]
; Prefer top-and-bottom style always
D200
;[ Docstring should be indented with spaces, not tabs ]
; Tabs, absolutely always
D206
;[ Use u""" for Unicode docstrings ]
; This must be for Python 2?
D302
;[ First line should end with a period ]
; First line should *NEVER* end with a period
D400
;[ First line should be in the imperative mood ]
; I like this for functions and methods, not for properties. This stands until
; pydocstyle splits a new code for properties or flake8 adds some way of
; filtering codes with line regexes like golangci-lint.
D401
;[ Line too long ]
; Prefer B950 implementation
E501
;[ multiple statements on one line (%s) ]
E701 E704
;[ unexpected number of spaces at start of statement line ]
;[ unexpected number of tabs and spaces at start of statement line ]
; Don't want spaces...
ET122 ET128
;[ Line break before binary operator ]
; Not considered current
W503
;[ Format-method string formatting ]
; Allow this style
SFS201
;[ f-string string formatting ]
; Allow this style
SFS301
include =
;[ First word of the docstring should not be This ]
D404
; flake8-bugbear plugin
; B950 is a replacement for E501
B0 B903 B950
; vim: sw=2 sts=2 expandtab
#!/usr/bin/env python3
"""Setuptools entrypoint"""
from setuptools import setup
setup()
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -16,74 +16,27 @@ ...@@ -16,74 +16,27 @@
Anyio helpers for unit tests Anyio helpers for unit tests
""" """
import sys from typing import Awaitable
from functools import wraps
from typing import Any
from typing import Callable
from typing import Coroutine
from typing import Literal
from typing import Tuple
from typing import Union
from unittest import TestCase
from unittest import mock from unittest import mock
from warnings import warn
import anyio import anyio
try:
import trio as _ # noqa
USE_TRIO = True
except ImportError:
USE_TRIO = False
Backend = Union[Literal['asyncio'], Literal['trio']] def _delay_side_effect(delay: float) -> Awaitable[None]:
async def coro(*a: object, **k: object) -> None:
await anyio.sleep(delay)
return coro
py_version = sys.version_info[:2]
AsyncTestFunc = Callable[..., Coroutine[Any, Any, None]] def patch_connect(delay: float = 0.0) -> mock._patch:
TestFunc = Callable[..., None] return mock.patch(
"wpa_supplicant.client.base.connect_unix_datagram",
side_effect=_delay_side_effect(delay),
)
def with_anyio(*backends: Backend, timeout: int = 10) -> Callable[[AsyncTestFunc], TestFunc]: def patch_send(delay: float = 0.0) -> mock._patch:
""" return mock.patch(
Create a wrapping decorator to run asynchronous test functions "wpa_supplicant.client.base.BaseClient.send_command",
""" side_effect=_delay_side_effect(delay),
if not backends: )
backends = ('asyncio',)
def decorator(testfunc: AsyncTestFunc) -> TestFunc:
async def test_async_wrapper(tc: TestCase, args: Tuple[mock.Mock]) -> None:
with anyio.fail_after(timeout):
await testfunc(tc, *args)
@wraps(testfunc)
def test_wrapper(tc: TestCase, *args: mock.Mock) -> None:
for backend in backends:
if backend == 'trio' and not USE_TRIO:
warn(
f"not running {testfunc.__name__} with trio; package is missing",
)
continue
with tc.subTest(f"backend: {backend}"):
anyio.run(test_async_wrapper, tc, args, backend=backend)
return test_wrapper
return decorator
class AsyncMock(mock.Mock):
"""
A Mock class that acts as a coroutine when called
"""
def __init__(self, *args: Any, delay: float = 0.0, **kwargs: Any):
mock._safe_super(AsyncMock, self).__init__(*args, **kwargs) # type: ignore
self.delay = delay
async def __call__(_mock_self, *args: Any, **kwargs: Any) -> Any:
_mock_self._mock_check_sig(*args, **kwargs)
if py_version >= (3, 8):
_mock_self._increment_mock_call(*args, **kwargs)
await anyio.sleep(_mock_self.delay)
return _mock_self._mock_call(*args, **kwargs)
# Copyright 2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,22 +18,20 @@ Test connecting and communicating with a server ...@@ -18,22 +18,20 @@ Test connecting and communicating with a server
import os import os
import sys import sys
from unittest import TestCase import unittest
from tests._anyio import with_anyio
from tests.integration.util import start_server from tests.integration.util import start_server
from wpa_supplicant.client import GlobalClient from wpa_supplicant.client import GlobalClient
class Tests(TestCase): class Tests(unittest.IsolatedAsyncioTestCase):
""" """
Tests against live wpa_suppplicant servers Tests against live wpa_suppplicant servers
The 'wpa_supplicant' executable is required in a PATH directory for these tests to work. The 'wpa_supplicant' executable is required in a PATH directory for these tests to work.
""" """
@with_anyio('asyncio', 'trio') async def test_connect(self) -> None:
async def test_connect(self):
""" """
Test connecting to the global wpa_supplicant control socket Test connecting to the global wpa_supplicant control socket
""" """
...@@ -42,8 +40,7 @@ class Tests(TestCase): ...@@ -42,8 +40,7 @@ class Tests(TestCase):
ifaces = await client.list_interfaces() ifaces = await client.list_interfaces()
assert len(ifaces) == 0 assert len(ifaces) == 0
@with_anyio('asyncio', 'trio') async def test_new_interface(self) -> None:
async def test_new_interface(self):
""" """
Test adding a wireless interface and scanning for stations Test adding a wireless interface and scanning for stations
......
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -17,74 +17,62 @@ Test cases for wpa_supplicant.client.base.BaseClient ...@@ -17,74 +17,62 @@ Test cases for wpa_supplicant.client.base.BaseClient
""" """
import unittest import unittest
from unittest import mock from unittest.mock import AsyncMock
import anyio import anyio
from tests import _anyio as anyio_mock from tests._anyio import patch_connect
from tests._anyio import patch_send
from wpa_supplicant import errors from wpa_supplicant import errors
from wpa_supplicant.client import base from wpa_supplicant.client import base
@mock.patch( class ConnectTests(unittest.IsolatedAsyncioTestCase):
"wpa_supplicant.client.base.connect_unix_datagram",
new_callable=anyio_mock.AsyncMock,
)
@mock.patch(
"wpa_supplicant.client.base.BaseClient.send_command",
new_callable=anyio_mock.AsyncMock,
)
class ConnectTests(unittest.TestCase):
""" """
Tests for the connect() method Tests for the connect() method
""" """
@anyio_mock.with_anyio() async def test_connect(self) -> None:
async def test_connect(self, _, connect_mock):
""" """
Check connect() calls socket.connect() Check connect() calls socket.connect()
""" """
async with base.BaseClient() as client: with patch_connect() as connect_mock, patch_send():
await client.connect("foo") async with base.BaseClient() as client:
await client.connect("foo")
connect_mock.assert_called_once_with("foo") connect_mock.assert_awaited_once_with("foo")
@anyio_mock.with_anyio() async def test_connect_timeout_1(self) -> None:
async def test_connect_timeout_1(self, _, connect_mock):
""" """
Check a socket.connect() delay causes TimeoutError to be raised Check a socket.connect() delay causes TimeoutError to be raised
""" """
connect_mock.delay = 2 with patch_connect(2.0), patch_send():
async with base.BaseClient() as client:
async with base.BaseClient() as client: with self.assertRaises(TimeoutError):
with self.assertRaises(TimeoutError): await client.connect("foo")
await client.connect("foo")
@anyio_mock.with_anyio() async def test_connect_timeout_2(self) -> None:
async def test_connect_timeout_2(self, send_mock, _):
""" """
Check a send/recv delay causes a TimeoutError to be raised Check a send/recv delay causes a TimeoutError to be raised
""" """
send_mock.delay = 2 with patch_connect(), patch_send(2.0):
async with base.BaseClient() as client:
async with base.BaseClient() as client: with self.assertRaises(TimeoutError):
with self.assertRaises(TimeoutError): await client.connect("foo")
await client.connect("foo")
class SendMessageTests(unittest.TestCase): class SendMessageTests(unittest.IsolatedAsyncioTestCase):
""" """
Tests for the send_command() method Tests for the send_command() method
""" """
def setUp(self): def setUp(self) -> None:
self.client = client = base.BaseClient() self.client = client = base.BaseClient()
client.sock = anyio_mock.AsyncMock(spec=anyio.abc.SocketStream) client.sock = AsyncMock(spec=anyio.abc.SocketStream)
client.sock.send.return_value = None client.sock.send.return_value = None
assert isinstance(client.sock, anyio.abc.SocketStream) assert isinstance(client.sock, anyio.abc.SocketStream)
@anyio_mock.with_anyio() async def test_simple(self) -> None:
async def test_simple(self):
""" """
Check that a response is processed after a command Check that a response is processed after a command
""" """
...@@ -92,8 +80,7 @@ class SendMessageTests(unittest.TestCase): ...@@ -92,8 +80,7 @@ class SendMessageTests(unittest.TestCase):
client.sock.receive.return_value = b"OK" client.sock.receive.return_value = b"OK"
assert await client.send_command("SOME_COMMAND") is None assert await client.send_command("SOME_COMMAND") is None
@anyio_mock.with_anyio() async def test_simple_expect(self) -> None:
async def test_simple_expect(self):
""" """
Check that an alternate expected response is processed Check that an alternate expected response is processed
""" """
...@@ -101,8 +88,7 @@ class SendMessageTests(unittest.TestCase): ...@@ -101,8 +88,7 @@ class SendMessageTests(unittest.TestCase):
client.sock.receive.return_value = b"PONG" client.sock.receive.return_value = b"PONG"
assert await client.send_command("PING", expect="PONG") is None assert await client.send_command("PING", expect="PONG") is None
@anyio_mock.with_anyio() async def test_simple_no_expect(self) -> None:
async def test_simple_no_expect(self):
""" """
Check that an unexpected response raises an UnexpectedResponseError Check that an unexpected response raises an UnexpectedResponseError
""" """
...@@ -113,8 +99,7 @@ class SendMessageTests(unittest.TestCase): ...@@ -113,8 +99,7 @@ class SendMessageTests(unittest.TestCase):
with self.assertRaises(errors.UnexpectedResponseError): with self.assertRaises(errors.UnexpectedResponseError):
await client.send_command("PING", expect="PONG") await client.send_command("PING", expect="PONG")
@anyio_mock.with_anyio() async def test_simple_convert(self) -> None:
async def test_simple_convert(self):
""" """
Check that a response is passed through a converter if given Check that a response is passed through a converter if given
""" """
...@@ -127,8 +112,7 @@ class SendMessageTests(unittest.TestCase): ...@@ -127,8 +112,7 @@ class SendMessageTests(unittest.TestCase):
["FOO", "BAR", "BAZ"], ["FOO", "BAR", "BAZ"],
) )
@anyio_mock.with_anyio() async def test_simple_convert_over_expect(self) -> None:
async def test_simple_convert_over_expect(self):
""" """
Check that 'convert' overrides 'expect' Check that 'convert' overrides 'expect'
""" """
...@@ -141,8 +125,7 @@ class SendMessageTests(unittest.TestCase): ...@@ -141,8 +125,7 @@ class SendMessageTests(unittest.TestCase):
["FOO", "BAR", "BAZ"], ["FOO", "BAR", "BAZ"],
) )
@anyio_mock.with_anyio() async def test_simple_fail(self) -> None:
async def test_simple_fail(self):
""" """
Check that a response of 'FAIL' causes CommandFailed to be raised Check that a response of 'FAIL' causes CommandFailed to be raised
""" """
...@@ -151,8 +134,7 @@ class SendMessageTests(unittest.TestCase): ...@@ -151,8 +134,7 @@ class SendMessageTests(unittest.TestCase):
with self.assertRaises(errors.CommandFailed): with self.assertRaises(errors.CommandFailed):
await client.send_command("SOME_COMMAND") await client.send_command("SOME_COMMAND")
@anyio_mock.with_anyio() async def test_simple_bad_command(self) -> None:
async def test_simple_bad_command(self):
""" """
Check that a response of 'UNKNOWN COMMAND' causes ValueError to be raised Check that a response of 'UNKNOWN COMMAND' causes ValueError to be raised
""" """
...@@ -161,8 +143,7 @@ class SendMessageTests(unittest.TestCase): ...@@ -161,8 +143,7 @@ class SendMessageTests(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
await client.send_command("SOME_COMMAND") await client.send_command("SOME_COMMAND")
@anyio_mock.with_anyio() async def test_interleaved(self) -> None:
async def test_interleaved(self):
""" """
Check that messages are processed alongside replies Check that messages are processed alongside replies
""" """
...@@ -175,8 +156,7 @@ class SendMessageTests(unittest.TestCase): ...@@ -175,8 +156,7 @@ class SendMessageTests(unittest.TestCase):
] ]
assert await client.send_command("SOME_COMMAND") is None assert await client.send_command("SOME_COMMAND") is None
@anyio_mock.with_anyio() async def test_unexpected(self) -> None:
async def test_unexpected(self):
""" """
Check that unexpected replies are logged cleanly Check that unexpected replies are logged cleanly
""" """
...@@ -190,8 +170,7 @@ class SendMessageTests(unittest.TestCase): ...@@ -190,8 +170,7 @@ class SendMessageTests(unittest.TestCase):
] ]
assert await client.event("CTRL-EVENT-EXAMPLE") assert await client.event("CTRL-EVENT-EXAMPLE")
@anyio_mock.with_anyio() async def test_unconnected(self) -> None:
async def test_unconnected(self):
""" """
Check that calling send_command() on an unconnected client raises RuntimeError Check that calling send_command() on an unconnected client raises RuntimeError
""" """
...@@ -200,8 +179,7 @@ class SendMessageTests(unittest.TestCase): ...@@ -200,8 +179,7 @@ class SendMessageTests(unittest.TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
await client.send_command("SOME_COMMAND") await client.send_command("SOME_COMMAND")
@anyio_mock.with_anyio() async def test_multi_task(self) -> None:
async def test_multi_task(self):
""" """
Check that calling send_command() from multiple tasks works as expected Check that calling send_command() from multiple tasks works as expected
""" """
...@@ -213,7 +191,7 @@ class SendMessageTests(unittest.TestCase): ...@@ -213,7 +191,7 @@ class SendMessageTests(unittest.TestCase):
(0.0, b"OK"), # Response to DETACH (0.0, b"OK"), # Response to DETACH
]) ])
async def recv(): async def recv() -> bytes:
delay, data = next(recv_responses) delay, data = next(recv_responses)
await anyio.sleep(delay) await anyio.sleep(delay)
return data return data
...@@ -222,7 +200,7 @@ class SendMessageTests(unittest.TestCase): ...@@ -222,7 +200,7 @@ class SendMessageTests(unittest.TestCase):
client.sock.receive.side_effect = recv client.sock.receive.side_effect = recv
@task_group.start_soon @task_group.start_soon
async def wait_for_event(): async def wait_for_event() -> None:
self.assertTupleEqual( self.assertTupleEqual(
await client.event("CTRL-FOO"), await client.event("CTRL-FOO"),
(base.EventPriority.INFO, "CTRL-FOO", None), (base.EventPriority.INFO, "CTRL-FOO", None),
...@@ -235,8 +213,7 @@ class SendMessageTests(unittest.TestCase): ...@@ -235,8 +213,7 @@ class SendMessageTests(unittest.TestCase):
# At this point the response to SOME_COMMAND1 is still delayed # At this point the response to SOME_COMMAND1 is still delayed
await client.send_command("SOME_COMMAND2", expect="REPLY2") await client.send_command("SOME_COMMAND2", expect="REPLY2")
@anyio_mock.with_anyio() async def test_multi_task_decode_error(self) -> None:
async def test_multi_task_decode_error(self):
""" """
Check that decode errors closes the socket and causes all tasks to raise EOFError Check that decode errors closes the socket and causes all tasks to raise EOFError
""" """
...@@ -251,27 +228,26 @@ class SendMessageTests(unittest.TestCase): ...@@ -251,27 +228,26 @@ class SendMessageTests(unittest.TestCase):
client.sock.receive.side_effect = recv_responses client.sock.receive.side_effect = recv_responses
@task_group.start_soon @task_group.start_soon
async def wait_for_event(): async def wait_for_event() -> None:
with self.assertRaises(anyio.ClosedResourceError): with self.assertRaises(anyio.ClosedResourceError):
await client.event("CTRL-FOO"), await client.event("CTRL-FOO")
await anyio.sleep(0.1) # Ensure send_command("ATTACH") has been sent await anyio.sleep(0.1) # Ensure send_command("ATTACH") has been sent
with self.assertRaises(anyio.ClosedResourceError): with self.assertRaises(anyio.ClosedResourceError):
await client.send_command("SOME_COMMAND", expect="REPLY") await client.send_command("SOME_COMMAND", expect="REPLY")
class EventTests(unittest.TestCase): class EventTests(unittest.IsolatedAsyncioTestCase):
""" """
Tests for the event() method Tests for the event() method
""" """
def setUp(self): def setUp(self) -> None:
self.client = client = base.BaseClient() self.client = client = base.BaseClient()
client.sock = anyio_mock.AsyncMock() client.sock = AsyncMock()
client.sock.send.return_value = None client.sock.send.return_value = None
@anyio_mock.with_anyio() async def test_simple(self) -> None:
async def test_simple(self):
""" """
Check that an awaited message is returned when is arrives Check that an awaited message is returned when is arrives
""" """
...@@ -287,8 +263,7 @@ class EventTests(unittest.TestCase): ...@@ -287,8 +263,7 @@ class EventTests(unittest.TestCase):
assert evt == "CTRL-EVENT-EXAMPLE" assert evt == "CTRL-EVENT-EXAMPLE"
assert args is None assert args is None
@anyio_mock.with_anyio() async def test_multiple(self) -> None:
async def test_multiple(self):
""" """
Check that an awaited messages is returned when it arrives between others Check that an awaited messages is returned when it arrives between others
""" """
...@@ -307,8 +282,7 @@ class EventTests(unittest.TestCase): ...@@ -307,8 +282,7 @@ class EventTests(unittest.TestCase):
assert evt == "CTRL-EVENT-EXAMPLE" assert evt == "CTRL-EVENT-EXAMPLE"
assert args is None assert args is None
@anyio_mock.with_anyio() async def test_wait_multiple(self) -> None:
async def test_wait_multiple(self):
""" """
Check that the first of several awaited events is returned Check that the first of several awaited events is returned
""" """
...@@ -330,8 +304,7 @@ class EventTests(unittest.TestCase): ...@@ -330,8 +304,7 @@ class EventTests(unittest.TestCase):
assert evt == "CTRL-EVENT-EXAMPLE3" assert evt == "CTRL-EVENT-EXAMPLE3"
assert args is None assert args is None
@anyio_mock.with_anyio() async def test_interleaved(self) -> None:
async def test_interleaved(self):
""" """
Check that messages are processed as well as replies Check that messages are processed as well as replies
""" """
...@@ -357,8 +330,7 @@ class EventTests(unittest.TestCase): ...@@ -357,8 +330,7 @@ class EventTests(unittest.TestCase):
assert await client.send_command("SOME_COMMAND", expect="FOO") is None assert await client.send_command("SOME_COMMAND", expect="FOO") is None
@anyio_mock.with_anyio() async def test_unconnected(self) -> None:
async def test_unconnected(self):
""" """
Check that calling event() on an unconnected client raises RuntimeError Check that calling event() on an unconnected client raises RuntimeError
""" """
......
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,25 +18,24 @@ Test cases for wpa_supplicant.client.GlobalClient ...@@ -18,25 +18,24 @@ Test cases for wpa_supplicant.client.GlobalClient
import pathlib import pathlib
import unittest import unittest
from unittest.mock import AsyncMock
from unittest.mock import patch from unittest.mock import patch
from tests import _anyio as anyio_mock
from wpa_supplicant.client import GlobalClient from wpa_supplicant.client import GlobalClient
from wpa_supplicant.client import InterfaceClient from wpa_supplicant.client import InterfaceClient
class InterfaceMethodsTests(unittest.TestCase): class InterfaceMethodsTests(unittest.IsolatedAsyncioTestCase):
""" """
Tests for the *_interface(s?) methods Tests for the *_interface(s?) methods
""" """
def setUp(self): def setUp(self) -> None:
self.client = client = GlobalClient() self.client = client = GlobalClient()
client.sock = anyio_mock.AsyncMock() client.sock = AsyncMock()
client.sock.send.return_value = None client.sock.send.return_value = None
@anyio_mock.with_anyio() async def test_connect(self) -> None:
async def test_connect(self):
""" """
Check that connect sets ctrl_dir Check that connect sets ctrl_dir
""" """
...@@ -45,7 +44,7 @@ class InterfaceMethodsTests(unittest.TestCase): ...@@ -45,7 +44,7 @@ class InterfaceMethodsTests(unittest.TestCase):
with patch( with patch(
"wpa_supplicant.client.base.BaseClient.connect", "wpa_supplicant.client.base.BaseClient.connect",
new_callable=anyio_mock.AsyncMock, new_callable=AsyncMock,
): ):
await client1.connect("/tmp/foo/bar") await client1.connect("/tmp/foo/bar")
await client2.connect(pathlib.Path("/tmp/foo/bar")) await client2.connect(pathlib.Path("/tmp/foo/bar"))
...@@ -56,8 +55,7 @@ class InterfaceMethodsTests(unittest.TestCase): ...@@ -56,8 +55,7 @@ class InterfaceMethodsTests(unittest.TestCase):
assert client1.ctrl_dir == pathlib.Path("/tmp/foo") assert client1.ctrl_dir == pathlib.Path("/tmp/foo")
assert client2.ctrl_dir == pathlib.Path("/tmp/foo") assert client2.ctrl_dir == pathlib.Path("/tmp/foo")
@anyio_mock.with_anyio() async def test_list_interfaces(self) -> None:
async def test_list_interfaces(self):
""" """
Check list_interfaces() processes lines of names in a list Check list_interfaces() processes lines of names in a list
""" """
...@@ -76,8 +74,7 @@ class InterfaceMethodsTests(unittest.TestCase): ...@@ -76,8 +74,7 @@ class InterfaceMethodsTests(unittest.TestCase):
client.sock.send.assert_called_once_with(b"INTERFACES") client.sock.send.assert_called_once_with(b"INTERFACES")
@anyio_mock.with_anyio() async def test_add_interface(self) -> None:
async def test_add_interface(self):
""" """
Check add_interface() sends the correct arguments Check add_interface() sends the correct arguments
""" """
...@@ -93,10 +90,9 @@ class InterfaceMethodsTests(unittest.TestCase): ...@@ -93,10 +90,9 @@ class InterfaceMethodsTests(unittest.TestCase):
@patch( @patch(
"wpa_supplicant.client.interfaces.InterfaceClient.connect", "wpa_supplicant.client.interfaces.InterfaceClient.connect",
new_callable=anyio_mock.AsyncMock, new_callable=AsyncMock,
) )
@anyio_mock.with_anyio() async def test_connect_interface(self, connect_mock: AsyncMock) -> None:
async def test_connect_interface(self, connect_mock):
""" """
Check connect_interface() returns a connected InterfaceClient Check connect_interface() returns a connected InterfaceClient
""" """
...@@ -116,10 +112,9 @@ class InterfaceMethodsTests(unittest.TestCase): ...@@ -116,10 +112,9 @@ class InterfaceMethodsTests(unittest.TestCase):
@patch( @patch(
"wpa_supplicant.client.interfaces.InterfaceClient.connect", "wpa_supplicant.client.interfaces.InterfaceClient.connect",
new_callable=anyio_mock.AsyncMock, new_callable=AsyncMock,
) )
@anyio_mock.with_anyio() async def test_connect_interface_with_add(self, connect_mock: AsyncMock) -> None:
async def test_connect_interface_with_add(self, connect_mock):
""" """
Check connect_interface() adds the interface when not already managed Check connect_interface() adds the interface when not already managed
""" """
...@@ -140,8 +135,7 @@ class InterfaceMethodsTests(unittest.TestCase): ...@@ -140,8 +135,7 @@ class InterfaceMethodsTests(unittest.TestCase):
self.assertTupleEqual(args[0][0], (b"INTERFACES",)) self.assertTupleEqual(args[0][0], (b"INTERFACES",))
assert args[1][0][0].startswith(b"INTERFACE_ADD enp1s0\t") assert args[1][0][0].startswith(b"INTERFACE_ADD enp1s0\t")
@anyio_mock.with_anyio() async def test_unconnected(self) -> None:
async def test_unconnected(self):
""" """
Check that calling add_interface() on an unconnected client raises RuntimeError Check that calling add_interface() on an unconnected client raises RuntimeError
......
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -17,26 +17,27 @@ Test cases for wpa_supplicant.client.interfaces.InterfaceClient ...@@ -17,26 +17,27 @@ Test cases for wpa_supplicant.client.interfaces.InterfaceClient
""" """
import unittest import unittest
from collections.abc import Iterator
from contextlib import contextmanager from contextlib import contextmanager
from unittest.mock import AsyncMock
from unittest.mock import call from unittest.mock import call
from tests import _anyio as anyio_mock
from wpa_supplicant import config from wpa_supplicant import config
from wpa_supplicant.client import interfaces from wpa_supplicant.client import interfaces
class MethodsTests(unittest.TestCase): class MethodsTests(unittest.IsolatedAsyncioTestCase):
""" """
Tests for InterfaceClient methods Tests for InterfaceClient methods
""" """
def setUp(self): def setUp(self) -> None:
self.client = client = interfaces.InterfaceClient() self.client = client = interfaces.InterfaceClient()
client.sock = anyio_mock.AsyncMock() client.sock = AsyncMock()
client.sock.send.return_value = None client.sock.send.return_value = None
@contextmanager @contextmanager
def subTest(self, *args, reset=[], **kwargs): def subTest(self, *args: object, reset: list[AsyncMock] = [], **kwargs: object) -> Iterator[None]:
with super().subTest(*args, **kwargs): with super().subTest(*args, **kwargs):
try: try:
yield yield
...@@ -44,8 +45,7 @@ class MethodsTests(unittest.TestCase): ...@@ -44,8 +45,7 @@ class MethodsTests(unittest.TestCase):
for mock in reset: for mock in reset:
mock.reset_mock() mock.reset_mock()
@anyio_mock.with_anyio() async def test_scan(self) -> None:
async def test_scan(self):
""" """
Check that a scan command waits for a notification then terminates correctly Check that a scan command waits for a notification then terminates correctly
""" """
...@@ -64,8 +64,7 @@ class MethodsTests(unittest.TestCase): ...@@ -64,8 +64,7 @@ class MethodsTests(unittest.TestCase):
self.assertIsInstance(bss, dict) self.assertIsInstance(bss, dict)
self.assertIn("good", bss) self.assertIn("good", bss)
@anyio_mock.with_anyio() async def test_set_network(self) -> None:
async def test_set_network(self):
""" """
Check that set_network sends values to the daemon and raises TypeError for bad types Check that set_network sends values to the daemon and raises TypeError for bad types
""" """
...@@ -105,8 +104,7 @@ class MethodsTests(unittest.TestCase): ...@@ -105,8 +104,7 @@ class MethodsTests(unittest.TestCase):
self.assertRaises(TypeError): self.assertRaises(TypeError):
await client.set_network("0", "key_mgmt", 1) await client.set_network("0", "key_mgmt", 1)
@anyio_mock.with_anyio() async def test_add_network(self) -> None:
async def test_add_network(self):
""" """
Check that add_network adds a new network and configures it Check that add_network adds a new network and configures it
""" """
......
""" """
Async control of WPA-Supplicant from a Python process Async control of WPA-Supplicant from a Python process
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -15,5 +15,3 @@ Async control of WPA-Supplicant from a Python process ...@@ -15,5 +15,3 @@ Async control of WPA-Supplicant from a Python process
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
__version__ = "0.3.0"
# Copyright 2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -19,47 +19,15 @@ Work-arounds for lack of AF_UNIX datagram socket support in Anyio ...@@ -19,47 +19,15 @@ Work-arounds for lack of AF_UNIX datagram socket support in Anyio
from __future__ import annotations from __future__ import annotations
import errno import errno
import os
import socket
import tempfile import tempfile
from contextlib import suppress from contextlib import suppress
from os import PathLike from os import PathLike
from typing import Any
from typing import Callable
from typing import Coroutine
from typing import Dict
from typing import Protocol
from typing import Union
from typing import cast
import sniffio from anyio import create_connected_unix_datagram_socket
from anyio.abc import ConnectedUNIXDatagramSocket as DatagramSocket
ConnectorFn = Callable[[str, str], Coroutine[Any, Any, 'DatagramSocket']]
connectors: Dict[str, ConnectorFn] = {} async def connect_unix_datagram(path: str | PathLike[str]) -> DatagramSocket:
class DatagramSocket(Protocol):
@property
def _raw_socket(self) -> socket.socket: ...
async def aclose(self) -> None: ...
async def receive(self) -> bytes: ...
async def send(self, item: bytes) -> None: ...
class ConnectedUNIXMixin:
async def aclose(self: DatagramSocket) -> None:
path = self._raw_socket.getsockname()
await super().aclose() # type: ignore # Mypy doesn't handle super() well in mixins
os.unlink(path)
async def connect_unix_datagram(path: Union[str, PathLike[str]]) -> DatagramSocket:
""" """
Return an AnyIO socket connected to a Unix datagram socket Return an AnyIO socket connected to a Unix datagram socket
...@@ -68,82 +36,7 @@ async def connect_unix_datagram(path: Union[str, PathLike[str]]) -> DatagramSock ...@@ -68,82 +36,7 @@ async def connect_unix_datagram(path: Union[str, PathLike[str]]) -> DatagramSock
for _ in range(10): for _ in range(10):
fname = tempfile.mktemp(suffix=".sock", prefix="wpa_ctrl.") fname = tempfile.mktemp(suffix=".sock", prefix="wpa_ctrl.")
with suppress(FileExistsError): with suppress(FileExistsError):
async_lib = sniffio.current_async_library() return await create_connected_unix_datagram_socket(path, local_path=fname)
connector = connectors[async_lib]
return await connector(fname, os.fspath(path))
raise FileExistsError( raise FileExistsError(
errno.EEXIST, "No usable temporary filename found", errno.EEXIST, "No usable temporary filename found",
) )
try:
import trio
except ImportError: ...
else:
from anyio._backends import _trio
class TrioConnectedUNIXSocket(ConnectedUNIXMixin, _trio.ConnectedUDPSocket):
...
async def trio_connect_unix_datagram(
local_path: str,
remote_path: str,
) -> TrioConnectedUNIXSocket:
sock = trio.socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
await sock.bind(local_path)
try:
await sock.connect(remote_path)
except BaseException: # pragma: no cover
sock.close()
raise
else:
return TrioConnectedUNIXSocket(sock)
connectors['trio'] = trio_connect_unix_datagram
# asyncio is in the stdlib, but lets make the layout match trio 😉
try:
import asyncio
except ImportError: ...
else:
from anyio._backends import _asyncio
class AsyncioConnectedUNIXSocket(ConnectedUNIXMixin, _asyncio.ConnectedUDPSocket):
...
async def asyncio_connect_unix_datagram(
local_path: str,
remote_path: str,
) -> AsyncioConnectedUNIXSocket:
await asyncio.sleep(0.0)
loop = asyncio.get_running_loop()
sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
sock.setblocking(False)
sock.bind(local_path)
while True:
try:
sock.connect(remote_path)
except BlockingIOError:
future: asyncio.Future[None] = asyncio.Future()
loop.add_writer(sock, future.set_result, None)
future.add_done_callback(lambda _: loop.remove_writer(sock))
await future
except BaseException:
sock.close()
raise
else:
break
transport_, protocol_ = await asyncio.get_running_loop().create_datagram_endpoint(
_asyncio.DatagramProtocol,
sock=sock,
)
transport = cast(asyncio.DatagramTransport, transport_)
protocol = cast(_asyncio.DatagramProtocol, protocol_)
if protocol.exception:
transport.close()
raise protocol.exception
return AsyncioConnectedUNIXSocket(transport, protocol)
connectors['asyncio'] = asyncio_connect_unix_datagram
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -16,14 +16,7 @@ ...@@ -16,14 +16,7 @@
WPA-Supplicant client classes WPA-Supplicant client classes
""" """
from ._global import GlobalClient from ._global import GlobalClient as GlobalClient
from .base import BaseClient from .base import BaseClient as BaseClient
from .consts import * from .consts import *
from .consts import __all__ as _consts_names from .interfaces import InterfaceClient as InterfaceClient
from .interfaces import InterfaceClient
__all__ = _consts_names + (
'BaseClient',
'GlobalClient',
'InterfaceClient',
)
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -20,7 +20,6 @@ from __future__ import annotations ...@@ -20,7 +20,6 @@ from __future__ import annotations
import pathlib import pathlib
from os import PathLike from os import PathLike
from typing import Set
from . import consts from . import consts
from .base import BaseClient from .base import BaseClient
...@@ -40,7 +39,7 @@ class GlobalClient(BaseClient): ...@@ -40,7 +39,7 @@ class GlobalClient(BaseClient):
await super().connect(path) await super().connect(path)
self.ctrl_dir = path.parent self.ctrl_dir = path.parent
async def list_interfaces(self) -> Set[str]: async def list_interfaces(self) -> set[str]:
""" """
Return a set of the interfaces currently managed by the daemon Return a set of the interfaces currently managed by the daemon
""" """
...@@ -52,14 +51,10 @@ class GlobalClient(BaseClient): ...@@ -52,14 +51,10 @@ class GlobalClient(BaseClient):
""" """
Add a network interface to the daemon's control interfaces Add a network interface to the daemon's control interfaces
""" """
if self.ctrl_dir:
ctrl_iface = f"DIR={self.ctrl_dir} GROUP={self.ctrl_dir.group()}"
else:
# RuntimeError should be raised by send_command() as connect() does not appear
# to have been called; set ctrl_iface to any string
ctrl_iface = ""
await self.send_command( await self.send_command(
consts.COMMAND_INTERFACE_ADD, ifname, "", driver, ctrl_iface, driver_param, consts.COMMAND_INTERFACE_ADD, ifname, "", driver,
f"DIR={self.ctrl_dir} GROUP={self.ctrl_dir.group()}" if self.ctrl_dir else "",
driver_param,
) )
assert self.ctrl_dir is not None, \ assert self.ctrl_dir is not None, \
"RuntimeError should be raised for sends on unconnected clients; " \ "RuntimeError should be raised for sends on unconnected clients; " \
......
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
""" """
This module provides a base WPA-Supplicant client implementation Base implementation for WPA-Supplicant client classes
""" """
from __future__ import annotations from __future__ import annotations
...@@ -21,17 +21,15 @@ from __future__ import annotations ...@@ -21,17 +21,15 @@ from __future__ import annotations
import enum import enum
import logging import logging
import os import os
import sys
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from re import compile as regex from re import compile as regex
from types import TracebackType as Traceback from types import TracebackType as Traceback
from typing import Any
from typing import AsyncContextManager
from typing import Callable from typing import Callable
from typing import Dict
from typing import Optional from typing import Optional
from typing import Tuple from typing import Tuple
from typing import Type
from typing import TypeVar from typing import TypeVar
from typing import Union
from typing import overload from typing import overload
import anyio import anyio
...@@ -53,7 +51,7 @@ class EventPriority(enum.IntEnum): ...@@ -53,7 +51,7 @@ class EventPriority(enum.IntEnum):
Event Message priorities Event Message priorities
""" """
def get_logger_level(self, *, _mapping: Dict[EventPriority, int] = {}) -> int: def get_logger_level(self, *, _mapping: dict[EventPriority, int] = {}) -> int:
""" """
Return a logging level matching the `wpa_supplicant` priority level Return a logging level matching the `wpa_supplicant` priority level
""" """
...@@ -93,15 +91,15 @@ class BaseClient: ...@@ -93,15 +91,15 @@ class BaseClient:
event_regex = regex(r"<([0-9]+)>(?:((?:CTRL|WPS|AP|P2P)-[A-Z0-9-]+)(?:\s|$))?(.+)?") event_regex = regex(r"<([0-9]+)>(?:((?:CTRL|WPS|AP|P2P)-[A-Z0-9-]+)(?:\s|$))?(.+)?")
def __init__(self, *, logger: Optional[logging.Logger] = None): def __init__(self, *, logger: logging.Logger | None = None) -> None:
self.logger = logger or logging.getLogger(__package__) self.logger = logger or logging.getLogger(__package__)
self.ctrl_dir = None self.ctrl_dir = None
self.sock: Optional[DatagramSocket] = None self.sock: DatagramSocket | None = None
self._lock = anyio.Lock() self._lock = anyio.Lock()
self._condition = anyio.Condition() self._condition = anyio.Condition()
self._handler_active = False self._handler_active = False
self._reply: Union[_ReplyState, str] = _ReplyState.NOTHING self._reply: _ReplyState | str = _ReplyState.NOTHING
self._event: Optional[EventInfo] self._event: EventInfo | None
self._eventcount = 0 self._eventcount = 0
async def __aenter__(self) -> BaseClient: async def __aenter__(self) -> BaseClient:
...@@ -109,9 +107,9 @@ class BaseClient: ...@@ -109,9 +107,9 @@ class BaseClient:
async def __aexit__( async def __aexit__(
self, self,
_et: Optional[Type[BaseException]], _et: type[BaseException] | None,
_e: Optional[BaseException], _e: BaseException | None,
_tb: Optional[Traceback], _tb: Traceback | None,
) -> None: ) -> None:
await self.disconnect() await self.disconnect()
...@@ -159,8 +157,8 @@ class BaseClient: ...@@ -159,8 +157,8 @@ class BaseClient:
*args: str, *args: str,
separator: str = consts.SEPARATOR_TAB, separator: str = consts.SEPARATOR_TAB,
expect: str = consts.RESPONSE_OK, expect: str = consts.RESPONSE_OK,
convert: Optional[Callable[[str], T]] = None, convert: Callable[[str], T] | None = None,
) -> Optional[T]: ) -> T | None:
""" """
Send a message and await a response Send a message and await a response
...@@ -215,11 +213,28 @@ class BaseClient: ...@@ -215,11 +213,28 @@ class BaseClient:
) )
return None return None
def attach(self) -> AsyncContextManager[None]: @asynccontextmanager
async def attach(self) -> AsyncIterator[None]:
""" """
Return a context manager that handles attaching to the daemon's message queue Return a context manager that handles attaching to the daemon's message queue
""" """
return self._AttachContext(self) assert self._eventcount >= 0
self._eventcount += 1
if self._eventcount == 1:
await self.send_command(consts.COMMAND_ATTACH)
try:
yield
except:
if __debug__:
exc_type, *_ = sys.exc_info()
assert exc_type is not None
self.logger.debug("Detaching due to %s", exc_type.__name__)
raise
finally:
assert self._eventcount > 0
self._eventcount -= 1
if self._eventcount == 0:
await self.send_command(consts.COMMAND_DETACH)
async def event(self, *events: str) -> EventInfo: async def event(self, *events: str) -> EventInfo:
""" """
...@@ -262,45 +277,25 @@ class BaseClient: ...@@ -262,45 +277,25 @@ class BaseClient:
raise anyio.ClosedResourceError raise anyio.ClosedResourceError
self.logger.debug("Received: %s", repr(msg)) self.logger.debug("Received: %s", repr(msg))
match = self.event_regex.match(msg)
match self._parse_message(msg):
# If matched, it is an event case str(msg):
if match: if self._reply is _ReplyState.AWAITING:
prio_, name, msg = match.groups() self._reply = msg
prio = EventPriority(int(prio_)) else:
self.logger.warning("Unexpected response message: %s", msg)
# If it's not an event, check whether a reply to a sent message is expected case [prio, name, message] if name is None:
elif self._reply is not _ReplyState.AWAITING: # Unnamed events are just for logging
self.logger.warning("Unexpected response message: %s", msg) assert message is not None, "empty log message received"
return self.logger.log(prio.get_logger_level(), message)
else: case [prio, str(name), message]:
self._reply = msg self._event = (prio, name, message)
return case _: # pragma: no-cover
raise AssertionError("unexpected return from BaseClient._parse_message()")
# Unnamed events are just for logging
if not name: @classmethod
self.logger.log(prio.get_logger_level(), msg) def _parse_message(cls, message: str) -> tuple[EventPriority, str|None, str|None] | str:
return if not (rematch := cls.event_regex.match(message)):
return message
self._event = (prio, name, msg or None) prio_, name, msg = rematch.groups()
return EventPriority(int(prio_)), name, msg
class _AttachContext:
def __init__(self, client: BaseClient):
self.client = client
async def __aenter__(self) -> None:
client = self.client
assert client._eventcount >= 0
if client._eventcount == 0:
await client.send_command(consts.COMMAND_ATTACH)
client._eventcount += 1
async def __aexit__(self, *exc_info: Any) -> None:
client = self.client
assert client._eventcount > 0
client._eventcount -= 1
if client._eventcount == 0:
if __debug__: # On it's own for compiler optimisation
if exc_info[0]:
client.logger.debug(f"Detaching due to {exc_info[0].__name__}")
await client.send_command(consts.COMMAND_DETACH)
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -20,7 +20,6 @@ from __future__ import annotations ...@@ -20,7 +20,6 @@ from __future__ import annotations
from itertools import count from itertools import count
from os import PathLike from os import PathLike
from typing import Any
from typing import AsyncGenerator from typing import AsyncGenerator
from typing import Dict from typing import Dict
...@@ -39,12 +38,15 @@ class InterfaceClient(BaseClient): ...@@ -39,12 +38,15 @@ class InterfaceClient(BaseClient):
name = None name = None
async def connect(self, path: PathLike[str]) -> None: async def connect(self, path: PathLike[str]) -> None:
"""
Connect to an interface UNIX port
"""
await super().connect(path) await super().connect(path)
self.name = await self.send_command(consts.COMMAND_IFNAME, convert=str) self.name = await self.send_command(consts.COMMAND_IFNAME, convert=str)
async def scan(self) -> AsyncGenerator[StringMap, None]: async def scan(self) -> AsyncGenerator[StringMap, None]:
""" """
Iteratively produces the details of all detectable IEEE 802.11 BSS Yield the details of all detectable IEEE 802.11 BSS
(WiFi Access Points to you and me) (WiFi Access Points to you and me)
""" """
...@@ -59,7 +61,7 @@ class InterfaceClient(BaseClient): ...@@ -59,7 +61,7 @@ class InterfaceClient(BaseClient):
return return
yield bss yield bss
async def add_network(self, configuration: Dict[str, Any]) -> int: async def add_network(self, configuration: dict[str, object]) -> int:
"""Add a new network configuration""" """Add a new network configuration"""
netid = await self.send_command(consts.COMMAND_ADD_NETWORK, convert=str) netid = await self.send_command(consts.COMMAND_ADD_NETWORK, convert=str)
for var, val in configuration.items(): for var, val in configuration.items():
...@@ -67,7 +69,7 @@ class InterfaceClient(BaseClient): ...@@ -67,7 +69,7 @@ class InterfaceClient(BaseClient):
await self.send_command(consts.COMMAND_ENABLE_NETWORK, netid) await self.send_command(consts.COMMAND_ENABLE_NETWORK, netid)
return int(netid) return int(netid)
async def set_network(self, netid: str, variable: str, value: Any) -> None: async def set_network(self, netid: str, variable: str, value: object) -> None:
"""Set a network configuration option""" """Set a network configuration option"""
if not isinstance(value, config.get_type(variable)): if not isinstance(value, config.get_type(variable)):
raise TypeError(f"Wrong type for {variable}: {value!r}") raise TypeError(f"Wrong type for {variable}: {value!r}")
...@@ -84,4 +86,4 @@ def _kv2dict(keyvalues: str) -> StringMap: ...@@ -84,4 +86,4 @@ def _kv2dict(keyvalues: str) -> StringMap:
""" """
Convert a list of line-terminated "key=value" substrings into a dictionary Convert a list of line-terminated "key=value" substrings into a dictionary
""" """
return dict(kv.split("=", 1) for kv in keyvalues.splitlines()) # type: ignore return dict(kv.split("=", 1) for kv in keyvalues.splitlines())
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk> # Copyright 2019-2021, 2024 Dom Sekotill <dom.sekotill@kodo.org.uk>
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,7 +18,6 @@ Helpers for network configuration ...@@ -18,7 +18,6 @@ Helpers for network configuration
from enum import Enum from enum import Enum
from enum import auto from enum import auto
from typing import Any
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
from typing import Optional from typing import Optional
...@@ -67,7 +66,7 @@ def get_type(variable: str) -> type: ...@@ -67,7 +66,7 @@ def get_type(variable: str) -> type:
class _UnknownTypeMeta(type): class _UnknownTypeMeta(type):
def __instancecheck__(cls, instance: Any) -> bool: def __instancecheck__(cls, instance: object) -> bool:
return isinstance(instance, (str, int)) return isinstance(instance, (str, int))
...@@ -82,7 +81,7 @@ class ConfigEnum(Enum): ...@@ -82,7 +81,7 @@ class ConfigEnum(Enum):
return str(self.value) return str(self.value)
@staticmethod @staticmethod
def _generate_next_value_(name: str, *_: Any) -> str: def _generate_next_value_(name: str, start: object, count: object, last_values: object) -> str:
return name.replace("_", "-") return name.replace("_", "-")
......