Skip to content
base.py 7.28 KiB
Newer Older
#  Copyright 2019  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
This module provides a base WPA-Supplicant client implementation
"""

import contextlib
import enum
import logging
import os
import pathlib
from re import compile as regex
from typing import Any, Callable, Optional, Sequence, Tuple

import anyio

from . import consts
from .. import errors, util
from ..types import PathLike
# Size of the buffer used for each datagram read from the control socket:
# 128 KiB (the largest message actually sent is slightly smaller than this)
MAX_DGRAM_READ = 128 * 1024


class EventPriority(enum.IntEnum):
	"""
	Event Message priorities

	Member values are the numeric "<N>" prefixes found at the start of
	unsolicited messages from the daemon.
	"""

	def get_logger_level(self, *, _mapping=None):
		"""
		Return a logging level matching the `wpa_supplicant` priority level

		The lookup table is built once at import time rather than lazily
		cached in a mutable default argument.  `_mapping` is kept only for
		backwards compatibility: a non-empty mapping passed here is used
		instead of the module table.
		"""
		mapping = _mapping or _EVENT_LOGGER_LEVELS
		return mapping[self]

	MSGDUMP = 0
	DEBUG = 1
	INFO = 2
	NOTICE = 3
	WARNING = 4
	ERROR = 5


# Closest standard `logging` level for each EventPriority
# fmt: off
_EVENT_LOGGER_LEVELS = {
	EventPriority.MSGDUMP: logging.DEBUG,
	EventPriority.DEBUG: logging.DEBUG,
	EventPriority.INFO: logging.INFO,
	EventPriority.NOTICE: logging.INFO,
	EventPriority.WARNING: logging.WARNING,
	EventPriority.ERROR: logging.ERROR,
}
# fmt: on


class BaseClient:
	"""
	A client for controlling a WPA-Supplicant daemon over a control socket

	This class is a naïve implementation. You probably want MasterClient or
	InterfaceClient instead.
	"""

	# Parses unsolicited messages from the daemon: a "<N>" priority prefix,
	# an optional event name (a CTRL-, WPS-, AP- or P2P- prefix followed by
	# upper-case letters, digits and dashes) and any remaining text.
	# Messages that do not match are treated as replies to commands.
	event_regex = regex(r"<([0-9]+)>(?:((?:CTRL|WPS|AP|P2P)-[A-Z0-9-]+)(?:\s|$))?(.+)?")

	def __init__(self, *, logger=None):
		self.logger = logger or logging.getLogger(__package__)
		self.ctrl_dir = None
		self.sock = None
		# Ensures only one coroutine at a time reads from the socket
		self._lock = anyio.create_lock()
		# Serialises commands and routes each reply back to its sender
		self._reply = ReplyManager()
		# Event name -> set of queues waiting for that event
		self._eventqueues = dict()
		# Number of open attach() contexts; ATTACH/DETACH are only sent on
		# the first entry and the last exit
		self._eventcount = 0

	async def __aenter__(self):
		return self

	async def __aexit__(self, *exc_info):
		await self.disconnect()

	async def connect(self, path: PathLike):
		"""
		Connect to a WPA-Supplicant daemon through the given address

		The connection is verified with a PING command; connecting and the
		PING exchange must complete within one second.

		Raises RuntimeError if the client already has a connection.
		"""
		if self.sock is not None:
			raise RuntimeError("cannot connect to multiple daemons")

		if not isinstance(path, pathlib.Path):
			path = pathlib.Path(os.fspath(path))

		async with anyio.fail_after(1):
			self.sock = await util.connect_unix_datagram(path.as_posix())
			await self.send_command(consts.COMMAND_PING, expect=consts.RESPONSE_PONG)

	async def disconnect(self):
		"""
		Disconnect from the connected daemon, if connected

		The socket reference is cleared first, so the client can `connect()`
		again and repeated calls (e.g. an explicit call followed by the
		context manager exit) do not close the same socket twice.
		"""
		sock, self.sock = self.sock, None
		if sock is not None:
			await sock.close()

	async def send_command(
		self,
		message: str,
		*args: str,
		separator: str = consts.SEPARATOR_TAB,
		expect: str = consts.RESPONSE_OK,
		convert: Optional[Callable] = None,
	) -> Any:
		"""
		Send a message and await a response

		Any `args` are appended to `message`, separated from it by a space and
		from each other by `separator`.

		If one of the failure responses described below is returned from the
		daemon an exception is raised.  If 'convert' is a callable the response
		will be passed to it as an argument and the result returned.  If the
		response matches 'expect' (default: "OK") None is returned.  Otherwise
		an UnexpectedResponseError is raised.

		The standard failure responses are:

		FAIL:
		  The command may have had bad arguments or could not be accepted.
		  Raises CommandFailed

		UNKNOWN COMMAND:
		  Either the command was not known or arguments were supplied to
		  a command that does not take arguments, or vice versa.
		  Raises ValueError

		MessageTooLargeError is raised if the socket did not accept the whole
		encoded message.
		"""
		if args:
			message = f"{message} {separator.join(args)}"
		msgbytes = message.encode()

		self.logger.debug("Sending: %s", repr(message))

		async with self._reply as queue:
			if len(msgbytes) != (await self.sock.send(msgbytes)):
				raise errors.MessageTooLargeError(msgbytes)

			# Continuously run _process() until the reply queue has a message
			while queue.empty():
				await self._process(queue)

			resp = await queue.get()

		if resp == consts.RESPONSE_FAIL:
			raise errors.CommandFailed(f"command returned FAIL: {message!r}")
		if resp == consts.RESPONSE_UNKNOWN_COMMAND:
			raise ValueError(f"Unknown command: {message!r}")
		if convert:
			return convert(resp)
		if resp != expect:
			raise errors.UnexpectedResponseError(
				f"Unexpected response to {message!r}: {resp!r}"
			)
		return None

	def attach(self):
		"""
		Return a context manager that handles attaching to the daemon's message queue
		"""
		return self._AttachContext(self)

	async def event(self, *events: str) -> Tuple[EventPriority, str, Optional[str]]:
		"""
		Await any of the given set of events

		Returns a (priority, event name, remaining text) tuple; the remaining
		text is None when the event carried no arguments.
		"""
		async with self.attach():
			with self._events_queue(events) as queue:
				while queue.empty():
					await self._process(queue)
				return await queue.get()

	async def _process(self, queue: anyio.Queue):
		"""
		Read and dispatch one message from the daemon

		Returns immediately if `queue` already holds a message.  Replies
		(messages not matching `event_regex`) are put on the pending command's
		reply queue; named events are put on every queue registered for that
		name; un-named events are logged at a level derived from their priority.
		"""
		async with self._lock:
			# Shortcut if the queue of interest has a message from another call
			# to _process() (probably in another coroutine)
			if not queue.empty():
				return

			msg = (await self.sock.recv(MAX_DGRAM_READ)).decode().strip()

		self.logger.debug("Received: %s", repr(msg))
		match = self.event_regex.match(msg)
		if not match:
			# If it's not an event, it must be a reply to a sent message.
			# Test against None explicitly: a queue object could be falsy
			# when empty if it defines __len__.
			if self._reply.queue is not None:
				await self._reply.queue.put(msg)
			else:
				self.logger.warning("Unexpected response message: %s", msg)
			# Replies have no event fields; falling through would call
			# .groups() on None
			return

		prio, name, msg = match.groups()
		prio = EventPriority(int(prio))

		if name is None:
			self.logger.log(prio.get_logger_level(), msg)
			return

		queues = self._eventqueues.get(name)
		if not queues:
			self.logger.debug("[unhandled] %s: %s", name, msg or "[no arguments]")
			return

		# Iterate over a snapshot: while put() awaits, a woken waiter may leave
		# its _events_queue() context and remove its queue from this set
		for msgqueue in list(queues):
			await msgqueue.put((prio, name, msg))

	@contextlib.contextmanager
	def _events_queue(self, events: Sequence[str]):
		"""
		Register one new single-slot queue for each of the named events

		The queue is yielded to the caller and unregistered on exit.  Event
		names left with no waiting queues are removed, so later occurrences
		are logged as unhandled again.
		"""
		evtqueues = self._eventqueues
		queue = anyio.create_queue(1)
		for evt in events:
			evtqueues.setdefault(evt, set()).add(queue)
		try:
			yield queue
		finally:
			for evt in events:
				queues = evtqueues.get(evt)
				if queues is None:
					# Name repeated in `events` and already cleaned up
					continue
				queues.discard(queue)
				if not queues:
					del evtqueues[evt]

	class _AttachContext:
		"""
		Async context manager keeping the client attached to daemon events

		Nested contexts share one attachment, counted in
		`client._eventcount`: ATTACH is sent when the count leaves zero and
		DETACH when it returns to zero.
		"""

		def __init__(self, client):
			self.client = client

		async def __aenter__(self):
			client = self.client
			assert client._eventcount >= 0
			if client._eventcount == 0:
				await client.send_command(consts.COMMAND_ATTACH)
			client._eventcount += 1

		async def __aexit__(self, *exc_info):
			client = self.client
			assert client._eventcount > 0
			client._eventcount -= 1
			if client._eventcount == 0:
				if __debug__:  # On its own so the block is optimised out under -O
					if exc_info[0]:
						client.logger.debug(f"Detaching due to {exc_info[0].__name__}")
				await client.send_command(consts.COMMAND_DETACH)


class ReplyManager:
	"""
	A context manager supplying a locked reply queue

	Entering acquires `lock` and creates a fresh single-slot queue, which is
	returned and also exposed as `queue` for the duration of the context.
	Exiting clears `queue` and releases the lock.  Other attribute lookups
	that fail normally are delegated to the current queue.
	"""

	def __init__(self):
		self.lock = anyio.create_lock()
		self.queue = None

	def __getattr__(self, name):
		# Only reached for attributes not found normally.  Guard "queue" so an
		# instance whose __init__ never ran (e.g. created by copy/pickle)
		# raises AttributeError instead of recursing forever.
		if name == "queue":
			raise AttributeError(name)
		return getattr(self.queue, name)

	async def __aenter__(self):
		await self.lock.__aenter__()
		self.queue = queue = anyio.create_queue(1)
		return queue

	async def __aexit__(self, *exc_info):
		self.queue, queue = None, self.queue
		await self.lock.__aexit__(*exc_info)
		# An unread reply only indicates a bug on a clean exit.  If the body
		# raised (e.g. a timeout after the reply arrived but before it was
		# read), don't replace that exception with an AssertionError.
		if exc_info[0] is None:
			assert queue.empty(), "Reply queue was not processed"