#  Copyright 2019-2021, 2024  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
Base implementation for WPA-Supplicant client classes
"""

from __future__ import annotations

import enum
import logging
import os
import sys
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from re import compile as regex
from types import TracebackType as Traceback
from typing import Callable
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import overload
from .. import errors
from .._anyio import DatagramSocket
from .._anyio import connect_unix_datagram
from . import consts
# Generic type variable; ties the return type of `BaseClient.send_command` to
# the return type of its `convert` callable
T = TypeVar('T')
# Event details as stored by `BaseClient._get_message`:
# (priority, event name, remaining message text or None)
EventInfo = Tuple['EventPriority', str, Optional[str]]

# 128kB (actual max size slightly less than this)
MAX_DGRAM_READ = 2 ** 17


class EventPriority(enum.IntEnum):
	"""
	Event Message priorities

	The integer values are the priority numbers that prefix event messages
	(e.g. the "3" in "<3>CTRL-EVENT-...").
	"""

	MSGDUMP = 0
	DEBUG = 1
	INFO = 2
	NOTICE = 3
	WARNING = 4
	ERROR = 5

	def get_logger_level(self, *, _mapping: dict[EventPriority, int] = {}) -> int:
		"""
		Return a logging level matching the `wpa_supplicant` priority level

		The `_mapping` argument is private; it must not be supplied by callers.
		"""
		# The mutable default is deliberate: it is shared between calls and acts as
		# a cache that is filled on first use, once the enum members exist.
		if not _mapping:
			cls = type(self)
			_mapping.update({
				cls.MSGDUMP: logging.DEBUG,
				cls.DEBUG: logging.DEBUG,
				cls.INFO: logging.INFO,
				cls.NOTICE: logging.INFO,
				cls.WARNING: logging.WARNING,
				cls.ERROR: logging.ERROR,
			})
		return _mapping[self]


class _ReplyState(enum.Enum):
	"""
	Sentinel states for `BaseClient._reply` when it does not hold a reply string
	"""

	# Initial state, assigned in `BaseClient.__init__`
	NOTHING = enum.auto()
	# Assigned by `BaseClient.send_command` before sending a command; replaced
	# with the reply text by `BaseClient._get_message`
	AWAITING = enum.auto()


class BaseClient:
	"""
	A client for controlling a WPA-Supplicant daemon over a control socket

	This class is a naïve implementation. You probably want GlobalClient and
	InterfaceClient.

	Instances are async context managers; leaving the context disconnects from
	the daemon.
	"""

	# Matches event messages: "<priority>", optionally followed by an event name
	# (CTRL-, WPS-, AP- or P2P- prefixed), optionally followed by remaining text
	event_regex = regex(r"<([0-9]+)>(?:((?:CTRL|WPS|AP|P2P)-[A-Z0-9-]+)(?:\s|$))?(.+)?")

	def __init__(self, *, logger: logging.Logger | None = None) -> None:
		self.logger = logger or logging.getLogger(__package__)
		self.ctrl_dir = None
		self.sock: DatagramSocket | None = None
		# Serialises commands: only one command may be awaiting a reply at a time
		self._lock = anyio.Lock()
		# Used with `_handler_active` to elect a single task to read the socket
		self._condition = anyio.Condition()
		self._handler_active = False
		self._reply: _ReplyState | str = _ReplyState.NOTHING
		self._event: EventInfo | None = None
		# Number of currently active `attach()` contexts
		self._eventcount = 0

	async def __aenter__(self) -> BaseClient:
		return self

	async def __aexit__(
		self,
		_et: type[BaseException] | None,
		_e: BaseException | None,
		_tb: Traceback | None,
	) -> None:
		await self.disconnect()

	async def connect(self, path: os.PathLike[str]) -> None:
		"""
		Connect to a WPA-Supplicant daemon through the given address

		The connection is checked by sending a PING command; connecting and the
		check are time-limited to one second in total.

		Raises RuntimeError if the client is already connected.
		"""
		if self.sock is not None:
			raise RuntimeError("cannot connect to multiple daemons")
		with anyio.fail_after(1.0):
			self.sock = await connect_unix_datagram(os.fspath(path))
			await self.send_command(consts.COMMAND_PING, expect=consts.RESPONSE_PONG)

	async def disconnect(self) -> None:
		"""
		Disconnect from the connected daemon, if connected
		"""
		# Forget the socket before closing it so the client is seen as disconnected
		# (and may be connected again) even if closing fails
		sock, self.sock = self.sock, None
		if sock is not None:
			await sock.aclose()

	@overload
	async def send_command(
		self,
		message: str,
		*args: str,
		separator: str = consts.SEPARATOR_TAB,
		expect: str = consts.RESPONSE_OK,
		convert: Callable[[str], T],
	) -> T: ...

	@overload
	async def send_command(
		self,
		message: str,
		*args: str,
		separator: str = consts.SEPARATOR_TAB,
		expect: str = consts.RESPONSE_OK,
		convert: None = None,
	) -> None: ...

	async def send_command(
		self,
		message: str,
		*args: str,
		separator: str = consts.SEPARATOR_TAB,
		expect: str = consts.RESPONSE_OK,
		convert: Callable[[str], T] | None = None,
	) -> T | None:
		"""
		Send a message and await a response

		Any 'args' are joined with 'separator' and appended to 'message' after
		a single space.

		If one of the failure responses described below is returned from the
		daemon an exception is raised.  If 'convert' is a callable the response
		will be passed to it as an argument and the result returned.  If the
		response matches 'expect' (default: "OK") None is returned.  Otherwise
		an UnexpectedResponseError is raised.

		The standard failure responses are:

		FAIL:
		  The command may have had bad arguments or could not be accepted.
		  Raises CommandFailed

		UNKNOWN COMMAND:
		  Either the command was not known or arguments were supplied to
		  a command that does not take arguments, or vice versa.
		  Raises ValueError

		Raises RuntimeError if the client is not connected.
		"""
		# Keep a local reference so a concurrent disconnect() cannot swap the socket
		# for None between this check and the send
		sock = self.sock
		if sock is None:
			raise RuntimeError("Client is not connected")

		if args:
			message = f"{message} {separator.join(args)}"
		msgbytes = message.encode()

		async with self._lock:
			# Set awaiting state for reply *before* sending, or currently active event
			# listeners will cause a race condition
			self._reply = awaiting = _ReplyState.AWAITING

			self.logger.debug("Sending: %s", repr(message))
			await sock.send(msgbytes)
			while self._reply is awaiting:
				await self._msg_notification()

			# Make sure the reply is retrieved within the lock context
			assert isinstance(self._reply, str)
			resp = self._reply

		if resp == consts.RESPONSE_FAIL:
			raise errors.CommandFailed(f"command returned FAIL: {message!r}")
		if resp == consts.RESPONSE_UNKNOWN_COMMAND:
			raise ValueError(f"Unknown command: {message!r}")
		if convert:
			return convert(resp)
		if resp != expect:
			raise errors.UnexpectedResponseError(
				f"Unexpected response to {message!r}: {resp!r}",
			)
		return None

	@asynccontextmanager
	async def attach(self) -> AsyncIterator[None]:
		"""
		Return a context manager that handles attaching to the daemon's message queue

		Contexts are counted: ATTACH is sent when the first context is entered and
		DETACH when the last one is left.
		"""
		assert self._eventcount >= 0
		self._eventcount += 1
		if self._eventcount == 1:
			try:
				await self.send_command(consts.COMMAND_ATTACH)
			except BaseException:
				# Attaching failed: undo the count so a later call sends ATTACH again
				self._eventcount -= 1
				raise
		try:
			yield
		except BaseException as exc:
			if __debug__:
				self.logger.debug("Detaching due to %s", type(exc).__name__)
			raise
		finally:
			assert self._eventcount > 0
			self._eventcount -= 1
			if self._eventcount == 0:
				await self.send_command(consts.COMMAND_DETACH)

	async def event(self, *events: str) -> EventInfo:
		"""
		Await any of the given set of events

		Returns the (priority, name, message) details of the first received event
		whose name is one of 'events'.
		"""
		async with self.attach():
			while True:
				await self._msg_notification()
				if self._event is None:
					continue
				if self._event[1] in events:
					return self._event

	async def _msg_notification(self) -> None:
		"""
		Wait until one message has been processed, by this task or another

		Only one task reads from the socket at a time; other callers wait to be
		notified by it and must then re-check the state they are interested in.
		"""
		condition = self._condition
		async with condition:
			# If another task is handling responses already, await a notification from it
			if self._handler_active:
				await condition.wait()
				# The handling task has already processed a message; reading another
				# here would discard its result and could block indefinitely
				return
			self._handler_active = True
			# Do not hold the condition's lock while blocked on the socket, so other
			# tasks can enter and wait
			condition.release()
			try:
				await self._get_message()
			finally:
				await condition.acquire()
				self._handler_active = False
				# Also wake waiters when handling failed, so that one of them can take
				# over (or fail in turn) instead of waiting forever
				condition.notify_all()

	async def _get_message(self) -> None:
		"""
		Receive one message and store it as a command reply or an event

		A reply is stored in `self._reply` if one is awaited, named events are
		stored in `self._event`, and unnamed events are only logged.
		"""
		self._event = None
		sock = self.sock
		if sock is None:
			# Disconnected by another task; report it as a closed connection
			raise anyio.ClosedResourceError

		try:
			msg = (await sock.receive()).decode().strip()
		except (UnicodeDecodeError, anyio.EndOfStream):
			await sock.aclose()
			raise anyio.ClosedResourceError

		self.logger.debug("Received: %s", repr(msg))
		match = self.event_regex.match(msg)

		# If it's not an event, check whether a reply to a sent message is expected
		if match is None:
			if self._reply is not _ReplyState.AWAITING:
				self.logger.warning("Unexpected response message: %s", msg)
			else:
				self._reply = msg
			return

		# If matched, it is an event
		prio_, name, text = match.groups()
		prio = EventPriority(int(prio_))

		# Unnamed events are just for logging
		if not name:
			self.logger.log(prio.get_logger_level(), text)
			return

		self._event = (prio, name, text or None)