# Copyright 2019-2021, 2024 Dom Sekotill
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module provides a base WPA-Supplicant client implementation
"""

from __future__ import annotations

import enum
import logging
import os
from re import compile as regex
from types import TracebackType as Traceback
from typing import AsyncContextManager
from typing import Callable
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import overload

import anyio

from .. import errors
from .._anyio import DatagramSocket
from .._anyio import connect_unix_datagram
from . import consts

T = TypeVar('T')

# (priority, event name, event message text or None) -- see BaseClient._get_message
EventInfo = Tuple['EventPriority', str, Optional[str]]

# 128kB (actual max size slightly less than this)
MAX_DGRAM_READ = 2 ** 17


class EventPriority(enum.IntEnum):
    """
    Event Message priorities
    """

    def get_logger_level(self, *, _mapping: dict[EventPriority, int] = {}) -> int:
        """
        Return a logging level matching the `wpa_supplicant` priority level
        """
        # The mutable default is deliberate: it is a cache shared by every member,
        # populated on the first call and reused afterwards.
        if not _mapping:
            cls = type(self)
            _mapping.update({
                cls.MSGDUMP: logging.DEBUG,
                cls.DEBUG: logging.DEBUG,
                cls.INFO: logging.INFO,
                cls.NOTICE: logging.INFO,
                cls.WARNING: logging.WARNING,
                cls.ERROR: logging.ERROR,
            })
        return _mapping[self]

    MSGDUMP = 0
    DEBUG = 1
    INFO = 2
    NOTICE = 3
    WARNING = 4
    ERROR = 5


class _ReplyState(enum.Enum):
    # Sentinel states for BaseClient._reply; once a reply arrives that attribute
    # holds the reply text (a str) instead of one of these members.
    NOTHING = enum.auto()
    AWAITING = enum.auto()


class BaseClient:
    """
    A client for controlling a WPA-Supplicant daemon over a control socket

    This class is a naïve implementation.
    You probably want GlobalClient and InterfaceClient.
    """

    # Groups: (1) numeric priority, (2) optional event name with a CTRL-/WPS-/AP-/P2P-
    # prefix, (3) optional remaining message text.  Messages that do not start with
    # "<N>" do not match and are treated as command replies.
    event_regex = regex(r"<([0-9]+)>(?:((?:CTRL|WPS|AP|P2P)-[A-Z0-9-]+)(?:\s|$))?(.+)?")

    def __init__(self, *, logger: logging.Logger | None = None) -> None:
        self.logger = logger or logging.getLogger(__package__)
        # NOTE(review): not referenced elsewhere in this module
        self.ctrl_dir = None
        self.sock: DatagramSocket | None = None
        # Serialises command/reply exchanges (see send_command)
        self._lock = anyio.Lock()
        # Together with _handler_active, elects a single task to read the socket
        # while others wait to be notified (see _msg_notification)
        self._condition = anyio.Condition()
        self._handler_active = False
        # NOTHING / AWAITING while no reply is available, or the reply text
        self._reply: _ReplyState | str = _ReplyState.NOTHING
        # The most recently parsed named event, or None; must exist before any
        # task inspects it, so initialise it here rather than only annotating it.
        self._event: EventInfo | None = None
        # Number of active attach() contexts; ATTACH/DETACH are sent on 0 <-> 1
        self._eventcount = 0

    async def __aenter__(self) -> BaseClient:
        return self

    async def __aexit__(
        self,
        _et: type[BaseException] | None,
        _e: BaseException | None,
        _tb: Traceback | None,
    ) -> None:
        await self.disconnect()

    async def connect(self, path: os.PathLike[str]) -> None:
        """
        Connect to a WPA-Supplicant daemon through the given address

        The socket connection must be made within 1 second; afterwards a PING is
        sent and a PONG reply is required (the PING itself is not time-limited).
        If the PING fails the socket stays assigned; call `disconnect()` before
        retrying.

        Raises:
            RuntimeError: if this client is already connected
        """
        if self.sock is not None:
            raise RuntimeError("cannot connect to multiple daemons")
        with anyio.fail_after(1.0):
            self.sock = await connect_unix_datagram(os.fspath(path))
        await self.send_command(consts.COMMAND_PING, expect=consts.RESPONSE_PONG)

    async def disconnect(self) -> None:
        """
        Disconnect from the connected daemon, if connected

        The socket reference is cleared so that `connect()` may be called again.
        """
        # Drop the reference first so the client is marked disconnected even if
        # closing the socket raises.
        sock, self.sock = self.sock, None
        if sock is not None:
            await sock.aclose()

    @overload
    async def send_command(
        self,
        message: str,
        *args: str,
        separator: str = consts.SEPARATOR_TAB,
        expect: str = consts.RESPONSE_OK,
        convert: Callable[[str], T],
    ) -> T: ...

    @overload
    async def send_command(
        self,
        message: str,
        *args: str,
        separator: str = consts.SEPARATOR_TAB,
        expect: str = consts.RESPONSE_OK,
        convert: None = None,
    ) -> None: ...

    async def send_command(
        self,
        message: str,
        *args: str,
        separator: str = consts.SEPARATOR_TAB,
        expect: str = consts.RESPONSE_OK,
        convert: Callable[[str], T] | None = None,
    ) -> T | None:
        """
        Send a message and await a response

        If one of the failure responses described below is returned from the daemon
        an exception is raised.

        If 'convert' is a callable the response will be passed to it as an argument
        and the result returned.  If the response matches 'expect' (default: "OK")
        None is returned.  Otherwise an UnexpectedResponseError is raised.

        The standard failure responses are:

        FAIL: The command may have had bad arguments or could not be accepted.
            Raises CommandFailed

        UNKNOWN COMMAND: Either the command was not known or arguments were
            supplied to a command that does not take arguments, or vice versa.
            Raises ValueError

        Raises:
            RuntimeError: if the client is not connected
        """
        if self.sock is None:
            raise RuntimeError("Client is not connected")

        if args:
            message = f"{message} {separator.join(args)}"
        msgbytes = message.encode()

        async with self._lock:
            # Set awaiting state for reply *before* sending, or currently active event
            # listeners will cause a race condition
            self._reply = awaiting = _ReplyState.AWAITING
            self.logger.debug("Sending: %s", repr(message))
            await self.sock.send(msgbytes)

            while self._reply is awaiting:
                await self._msg_notification()

            # Make sure the reply is retrieved within the lock context
            assert isinstance(self._reply, str)
            resp = self._reply

        if resp == consts.RESPONSE_FAIL:
            raise errors.CommandFailed(f"command returned FAIL: {message!r}")
        if resp == consts.RESPONSE_UNKNOWN_COMMAND:
            raise ValueError(f"Unknown command: {message!r}")
        if convert:
            return convert(resp)
        if resp != expect:
            raise errors.UnexpectedResponseError(
                f"Unexpected response to {message!r}: {resp!r}",
            )
        return None

    def attach(self) -> AsyncContextManager[None]:
        """
        Return a context manager that handles attaching to the daemon's message queue

        Nested/concurrent contexts are reference counted: ATTACH is sent when the
        first one is entered and DETACH when the last one exits.
        """
        return self._AttachContext(self)

    async def event(self, *events: str) -> EventInfo:
        """
        Await any of the given set of events

        The client is attached for the duration of the wait.  Returns the
        (priority, name, message) tuple of the first event whose name is in
        `events`.
        """
        async with self.attach():
            while True:
                await self._msg_notification()
                if self._event is None:
                    continue
                if self._event[1] in events:
                    return self._event

    async def _msg_notification(self) -> None:
        """
        Wait until the next message from the daemon has been processed

        Only one task reads from the socket at a time.  A task that finds no
        reader active becomes the reader: it releases the condition while it
        waits for a datagram (so other tasks can queue up on it), then
        re-acquires it to clear the flag and wake every waiting task.
        """
        condition = self._condition
        async with condition:
            # If another task is handling responses already, await a
            # notification from it
            if self._handler_active:
                await condition.wait()
                return
            self._handler_active = True
            condition.release()
            try:
                await self._get_message()
            finally:
                await condition.acquire()
                self._handler_active = False
                condition.notify_all()

    async def _get_message(self) -> None:
        """
        Receive one datagram and dispatch it

        Replies are stored in `_reply` (when one is awaited), named events in
        `_event`, and unnamed events are only logged.

        Raises:
            anyio.ClosedResourceError: if the client has been disconnected, the
                peer closed the stream, or the message could not be decoded (in the
                latter two cases the socket is closed first)
        """
        self._event = None
        sock = self.sock
        if sock is None:
            # disconnect() ran while this task was queued behind another reader
            raise anyio.ClosedResourceError
        try:
            msg = (await sock.receive()).decode().strip()
        except (UnicodeDecodeError, anyio.EndOfStream):
            await sock.aclose()
            raise anyio.ClosedResourceError
        self.logger.debug("Received: %s", repr(msg))
        match = self.event_regex.match(msg)

        # If matched, it is an event
        if match:
            prio_, name, msg = match.groups()
            prio = EventPriority(int(prio_))
        # If it's not an event, check whether a reply to a sent message is expected
        elif self._reply is not _ReplyState.AWAITING:
            self.logger.warning("Unexpected response message: %s", msg)
            return
        else:
            self._reply = msg
            return

        # Unnamed events are just for logging
        if not name:
            self.logger.log(prio.get_logger_level(), msg)
            return

        self._event = (prio, name, msg or None)

    class _AttachContext:
        """
        Reference-counted ATTACH/DETACH of a client to the daemon's message queue
        """

        def __init__(self, client: BaseClient) -> None:
            self.client = client

        async def __aenter__(self) -> None:
            client = self.client
            assert client._eventcount >= 0
            if client._eventcount == 0:
                await client.send_command(consts.COMMAND_ATTACH)
            # Only counted once ATTACH (if needed) has succeeded
            client._eventcount += 1

        async def __aexit__(self, exc_type: type[BaseException]|None, *exc_info: object) -> None:
            client = self.client
            assert client._eventcount > 0
            client._eventcount -= 1
            if client._eventcount == 0:
                if __debug__ and exc_type:
                    client.logger.debug("Detaching due to %s", exc_type.__name__)
                await client.send_command(consts.COMMAND_DETACH)