Skip to content
_anyio.py 3.7 KiB
Newer Older
#  Copyright 2021  Dom Sekotill <dom.sekotill@kodo.org.uk>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
Work-arounds for lack of AF_UNIX datagram socket support in Anyio
"""

import abc
import errno
import os
import socket
import tempfile
from contextlib import suppress
from typing import TYPE_CHECKING, cast

import anyio.abc

try:
	from anyio import _get_asynclib
except ImportError:
	from anyio._core._eventloop import get_asynclib as _get_asynclib

from .types import PathLike


class ConnectedUNIXAbstract(abc.ABC):
	"""
	Declares the attributes that ConnectedUNIXMixin relies on

	Concrete classes obtain these from an AnyIO backend socket class listed
	after the mixin in their bases.
	"""

	_raw_socket: socket.socket

	if TYPE_CHECKING:
		# Declared for type checkers only.  A runtime definition here would sit
		# between the mixin and the backend socket class in the MRO, so the
		# mixin's super().aclose() would reach this stub instead of the
		# backend's aclose(), and the socket would never actually be closed.
		@abc.abstractmethod
		async def aclose(self) -> None: ...


class ConnectedUNIXMixin(ConnectedUNIXAbstract):
	"""
	Mixin that removes a socket's bound filesystem path when it is closed
	"""

	async def aclose(self) -> None:
		"""
		Close the socket, then unlink the local path it was bound to

		The path is removed even if closing the socket raises.  A path that no
		longer exists is ignored.
		"""
		# Must be read before closing; the name is unavailable afterwards
		path = self._raw_socket.getsockname()
		try:
			await super().aclose()
		finally:
			with suppress(FileNotFoundError):
				os.unlink(path)


async def connect_unix_datagram(path: PathLike) -> anyio.abc.SocketStream:
	"""
	Return an AnyIO socket connected to a Unix datagram socket

	This behaviour is currently missing from AnyIO.

	The local end is bound to a freshly generated temporary path so that the
	peer has an address to reply to.  Up to ten candidate names are tried if
	the chosen path is already in use.

	Raises FileExistsError if no unused temporary path could be bound.  Any
	other error from binding or connecting propagates unchanged.
	"""
	for _ in range(10):
		# mktemp() only proposes a name; bind() is what claims it, failing
		# with EADDRINUSE if something already exists at that path.
		fname = tempfile.mktemp(suffix=".sock", prefix="wpa_ctrl.")
		try:
			return await _get_asynclib().connect_unix_datagram(
				local_path=fname,
				remote_path=path,
			)
		except OSError as exc:
			# EEXIST is kept for parity with the original FileExistsError check
			if exc.errno not in (errno.EADDRINUSE, errno.EEXIST):
				raise
	raise FileExistsError(
		errno.EEXIST, "No usable temporary filename found",
	)


try:
	import trio
except ImportError: ...
else:
	from anyio._backends import _trio

	class TrioConnectedUNIXSocket(ConnectedUNIXMixin, _trio.ConnectedUDPSocket):  # type: ignore
		"""
		AnyIO trio-backend datagram socket whose aclose() also unlinks the
		local path it was bound to
		"""

	async def trio_connect_unix_datagram(
		local_path: PathLike,
		remote_path: PathLike,
	) -> TrioConnectedUNIXSocket:
		"""
		Bind a Unix datagram socket to local_path and connect it to remote_path

		Installed below as the trio backend's connect_unix_datagram().

		On any failure the socket is closed before the exception propagates.
		If the failure happens after a successful bind, local_path is also
		removed.
		"""
		sock = trio.socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
		try:
			await sock.bind(os.fspath(local_path))
		except BaseException:
			# bind() failed, so local_path is not ours to remove (it may be
			# another socket's path, e.g. on EADDRINUSE); just release the fd
			sock.close()
			raise
		try:
			await sock.connect(os.fspath(remote_path))
		except BaseException:  # pragma: no cover
			sock.close()
			with suppress(FileNotFoundError):
				os.unlink(local_path)
			raise
		return TrioConnectedUNIXSocket(sock)

	_trio.connect_unix_datagram = trio_connect_unix_datagram


# asyncio is in the stdlib, but let's make the layout match trio 😉
try:
	import asyncio
except ImportError: ...
else:
	from anyio._backends import _asyncio

	class AsyncioConnectedUNIXSocket(ConnectedUNIXMixin, _asyncio.ConnectedUDPSocket):  # type: ignore
		"""
		AnyIO asyncio-backend datagram socket whose aclose() also unlinks the
		local path it was bound to
		"""

	async def asyncio_connect_unix_datagram(
		local_path: PathLike,
		remote_path: PathLike,
	) -> AsyncioConnectedUNIXSocket:
		"""
		Bind a Unix datagram socket to local_path and connect it to remote_path

		Installed below as the asyncio backend's connect_unix_datagram().

		On any failure, including cancellation while waiting to connect, the
		socket is closed before the exception propagates.  If the failure
		happens after a successful bind, local_path is also removed.
		"""
		await asyncio.sleep(0.0)  # yield to the event loop once before any work
		loop = asyncio.get_running_loop()
		sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
		try:
			sock.setblocking(False)
			sock.bind(os.fspath(local_path))
		except BaseException:
			# bind() failed, so local_path is not ours to remove (it may be
			# another socket's path, e.g. on EADDRINUSE); just release the fd
			sock.close()
			raise

		try:
			remote = os.fspath(remote_path)
			while True:
				try:
					sock.connect(remote)
				except BlockingIOError:
					pass
				else:
					break
				# Wait for writability, then retry.  The wait lives inside the
				# outer try (not inside an except clause) so that cancellation
				# here still reaches the cleanup below.
				writable: asyncio.Future[None] = loop.create_future()
				loop.add_writer(sock, writable.set_result, None)
				try:
					await writable
				finally:
					loop.remove_writer(sock)

			transport_, protocol_ = await loop.create_datagram_endpoint(
				_asyncio.DatagramProtocol,
				sock=sock,
			)
		except BaseException:
			# Closing is harmless if asyncio already closed the socket
			sock.close()
			with suppress(FileNotFoundError):
				os.unlink(local_path)
			raise

		transport = cast(asyncio.DatagramTransport, transport_)
		protocol = cast(_asyncio.DatagramProtocol, protocol_)
		if protocol.exception:
			transport.close()
			with suppress(FileNotFoundError):
				os.unlink(local_path)
			raise protocol.exception
		return AsyncioConnectedUNIXSocket(transport, protocol)

	_asyncio.connect_unix_datagram = asyncio_connect_unix_datagram