Skip to content
base.py 7.13 KiB
Newer Older
#  Copyright 2019  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
This module provides a base WPA-Supplicant client implementation
"""

import contextlib
import enum
import logging
import pathlib
import os
from re import compile as regex
from typing import Any, Callable, Optional, Set, Sequence, Tuple, Union

import anyio

from . import consts
from .. import errors, util


# Size of the buffer handed to recv() for each datagram: 128 KiB.
# (The actual maximum message size is slightly less than this.)
MAX_DGRAM_READ = 1 << 17


class EventPriority(enum.IntEnum):
	"""
	Priority levels carried by event messages from `wpa_supplicant`
	"""

	MSGDUMP = 0
	DEBUG = 1
	INFO = 2
	NOTICE = 3
	WARNING = 4
	ERROR = 5

	def get_logger_level(self, *, _mapping={}):
		"""
		Translate this priority into the closest standard `logging` level

		The translation table is built on first use and stored in the
		(deliberately mutable) default of `_mapping`, so it is shared by every
		member and only populated once.
		"""
		table = _mapping
		if len(table) == 0:
			cls = type(self)
			table.update((
				(cls.MSGDUMP, logging.DEBUG),
				(cls.DEBUG, logging.DEBUG),
				(cls.INFO, logging.INFO),
				(cls.NOTICE, logging.INFO),
				(cls.WARNING, logging.WARNING),
				(cls.ERROR, logging.ERROR),
			))
		return table[self]


class BaseClient:
	"""
	A client for controlling a WPA-Supplicant daemon over a control socket

	This class is a naïve implementation. You probably want MasterClient and 
	InterfaceClient.
	"""

	# Event messages look like "<PRIO>NAME[ ARGS]"; anything else is treated
	# as a reply to the most recently sent command.
	event_regex = regex(r'<([0-9]+)>(\S+)(?:\s(.*))?')

	def __init__(self, *, logger=None):
		self.logger = logger or logging.getLogger(__package__)
		self.ctrl_dir = None
		self.sock = None
		# Serialises reads from the socket in _process()
		self._lock = anyio.create_lock()
		# Serialises commands and holds the queue a command's reply is put on
		self._reply = ReplyManager()
		# Event name -> set of subscriber queues (see _events_queue()); a name
		# is only present while it has at least one subscriber
		self._eventqueues = dict()
		# Number of currently active attach() contexts
		self._eventcount = 0

	async def __aenter__(self):
		return self

	async def __aexit__(self, *exc_info):
		await self.disconnect()

	async def connect(self, path: consts.Path):
		"""
		Connect to a WPA-Supplicant daemon through the given address

		After connecting, the daemon is sent a PING and must answer PONG, with
		the whole operation limited to one second.  If any step fails, the
		socket is closed again before the exception propagates, so connect()
		may be retried.

		Raises RuntimeError if the client is already connected.
		"""
		if self.sock is not None:
			raise RuntimeError("cannot connect to multiple daemons")

		if not isinstance(path, pathlib.Path):
			path = pathlib.Path(os.fspath(path))

		try:
			async with anyio.fail_after(1):
				self.sock = await util.connect_unix_datagram(path.as_posix())
				await self.send_command(consts.COMMAND_PING, expect=consts.RESPONSE_PONG)
		except BaseException:
			# Don't leave a half-open socket behind: it would leak, and every
			# later connect() would fail with "cannot connect to multiple daemons"
			await self.disconnect()
			raise

	async def disconnect(self):
		"""
		Disconnect from the connected daemon, if connected

		Safe to call more than once; afterwards connect() may be called again.
		"""
		# Clear the attribute before awaiting so the client's state is reset
		# even if close() is interrupted
		sock, self.sock = self.sock, None
		if sock is not None:
			await sock.close()

	async def send_command(self,
			message: str,
			*args: str,
			seperator: str = '\t',
			expect: str = consts.RESPONSE_OK,
			convert: Optional[Callable] = None,
			) -> Any:
		"""
		Send a message and await a response

		Any 'args' are joined with 'seperator' (default: a tab) and appended to
		'message' after a single space.

		If one of the failure responses described below are returned from the 
		daemon an exception is raised.  If 'convert' is a callable the response 
		will be passed to it as an argument and the result returned.  If the 
		response matches 'expect' (default: "OK") None is returned.  Otherwise 
		an UnexpectedResponseError is raised.

		If the encoded message cannot be sent in a single datagram
		a MessageTooLargeError is raised.

		The standard failure responses are:

		FAIL:
		  The command may have had bad arguments or could not be accepted.  
		  Raises CommandFailed

		UNKNOWN COMMAND:
		  Either the command was not known or arguments were supplied to 
		  a command that does not take arguments, or vice versa.
		  Raises ValueError
		"""
		if args:
			message = f'{message} {seperator.join(args)}'
		msgbytes = message.encode()

		self.logger.debug("Sending: %s", repr(message))

		async with self._reply as queue:
			if len(msgbytes) != (await self.sock.send(msgbytes)):
				raise errors.MessageTooLargeError(msgbytes)

			# Continuously run _process() until the reply queue has a message
			while queue.empty():
				await self._process(queue)

			resp = await queue.get()

		if resp == consts.RESPONSE_FAIL:
			raise errors.CommandFailed(f"command returned FAIL: {message!r}")
		if resp == consts.RESPONSE_UNKNOWN_COMMAND:
			raise ValueError(f"Unknown command: {message!r}")
		if convert:
			return convert(resp)
		if resp != expect:
			raise errors.UnexpectedResponseError(
					f"Unexpected response to {message!r}: {resp!r}")
		return None

	def attach(self):
		"""
		Return a context manager that handles attaching to the daemon's message queue

		Contexts may be nested: ATTACH is sent when the first one is entered
		and DETACH when the last one is exited.
		"""
		return self._AttachContext(self)

	async def event(self, *events: str) -> Tuple[EventPriority, str, Optional[str]]:
		"""
		Await any of the given set of events

		Returns a (priority, name, args) tuple for the first matching event;
		'args' is None when the event message carried no arguments.
		"""
		async with self.attach():
			with self._events_queue(events) as queue:
				while queue.empty():
					await self._process(queue)
				return await queue.get()

	async def _process(self, queue: anyio.Queue):
		"""
		Receive and dispatch a single message from the daemon

		Messages not matching event_regex are put on the pending reply queue,
		or logged as unexpected if no command is waiting.  Events are put on
		every queue subscribed to the event's name, or logged at a level
		derived from their priority if nothing is subscribed.

		'queue' is the caller's queue of interest: if it already holds
		a message once the socket lock is acquired, nothing is read.
		"""
		async with self._lock:
			# Shortcut if the queue of interest has a message from another call 
			# to _process() (probably in another coroutine)
			if not queue.empty():
				return

			msg = (await self.sock.recv(MAX_DGRAM_READ)).decode().strip()

		self.logger.debug("Received: %s", repr(msg))
		match = self.event_regex.match(msg)
		if not match:
			# If it's not an event, it must be a reply to a sent message
			reply_queue = self._reply.queue
			if reply_queue is not None:
				await reply_queue.put(msg)
			else:
				self.logger.warning("Unexpected response message: %s", msg)
			return

		prio, name, args = match.groups()
		prio = EventPriority(int(prio))

		try:
			queues = self._eventqueues[name]
		except KeyError:
			self.logger.log(prio.get_logger_level(), "%s [UNEXPECTED]: %s", name, args)
		else:
			# Iterate over a snapshot: put() may suspend, and meanwhile other
			# coroutines can subscribe/unsubscribe, mutating the live set
			for msgqueue in tuple(queues):
				if msgqueue in queues:  # skip queues unsubscribed meanwhile
					await msgqueue.put((prio, name, args))

	@contextlib.contextmanager
	def _events_queue(self, events: Union[Sequence, Set]):
		"""
		Subscribe a new size-1 queue to each named event for the duration of
		the with-block, yielding that queue

		Duplicate names are subscribed once.  On exit the queue is removed
		again, and names left with no subscribers are dropped from
		_eventqueues so that their later occurrences are logged as unexpected.
		"""
		evtqueues = self._eventqueues
		names = set(events)
		queue = anyio.create_queue(1)
		for name in names:
			evtqueues.setdefault(name, set()).add(queue)
		try:
			yield queue
		finally:
			for name in names:
				subscribers = evtqueues[name]
				subscribers.discard(queue)
				if not subscribers:
					del evtqueues[name]

	class _AttachContext:
		"""
		Reference-counted ATTACH/DETACH for a client (see BaseClient.attach())
		"""

		def __init__(self, client):
			self.client = client

		async def __aenter__(self):
			client = self.client
			assert client._eventcount >= 0
			if client._eventcount == 0:
				await client.send_command(consts.COMMAND_ATTACH)
			client._eventcount += 1

		async def __aexit__(self, *exc_info):
			client = self.client
			assert client._eventcount > 0
			client._eventcount -= 1
			if client._eventcount == 0:
				if __debug__: # On its own for compiler optimisation
					if exc_info[0]:
						client.logger.debug("Detaching due to %s", exc_info[0].__name__)
				await client.send_command(consts.COMMAND_DETACH)


class ReplyManager:
	"""
	A context manager supplying a locked reply queue

	Entering acquires `lock` and creates a fresh size-1 queue, which is both
	the `async with` target and the `queue` attribute.  Exiting clears
	`queue` (back to None) and releases the lock.  Other attribute lookups
	are forwarded to the current queue.
	"""

	def __init__(self):
		self.lock = anyio.create_lock()
		self.queue = None

	def __getattr__(self, name):
		# Only reached for names not found normally.  Refuse our own attribute
		# names so that a missing 'queue' (e.g. before __init__ has run)
		# raises AttributeError instead of recursing forever.
		if name in ('lock', 'queue'):
			raise AttributeError(name)
		return getattr(self.queue, name)

	async def __aenter__(self):
		await self.lock.__aenter__()
		self.queue = queue = anyio.create_queue(1)
		return queue

	async def __aexit__(self, *exc_info):
		self.queue, queue = None, self.queue
		await self.lock.__aexit__(*exc_info)
		# An unconsumed reply is only a bug when the block exited normally;
		# asserting while an exception (e.g. a cancellation) is propagating
		# would replace that exception with an AssertionError.
		if exc_info[0] is None:
			assert queue.empty(), "Reply queue was not processed"