Newer
Older
# Copyright 2019-2021 Dom Sekotill <dom.sekotill@kodo.org.uk>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test cases for wpa_supplicant.client.base.BaseClient
"""
import unittest
from unittest import mock
import anyio
from tests import _anyio as anyio_mock
from wpa_supplicant import errors
from wpa_supplicant.client import base


@mock.patch(
    "wpa_supplicant.client.base.connect_unix_datagram",
    new_callable=anyio_mock.AsyncMock,
)
@mock.patch(
    "wpa_supplicant.client.base.BaseClient.send_command",
    new_callable=anyio_mock.AsyncMock,
)
class ConnectTests(unittest.TestCase):
    """
    Tests for the connect() method

    Both class-level patches apply to every test method. Stacked patches pass
    their mocks bottom-up, so each test receives the BaseClient.send_command
    mock first and the connect_unix_datagram mock second.
    """

    @anyio_mock.with_anyio()
    async def test_connect(self, _, connect_mock):
        """
        Check connect() opens the socket by calling connect_unix_datagram()
        with the argument it was given
        """
        async with base.BaseClient() as client:
            await client.connect("foo")
            connect_mock.assert_called_once_with("foo")

    @anyio_mock.with_anyio()
    async def test_connect_timeout_1(self, _, connect_mock):
        """
        Check a delay while opening the socket causes TimeoutError to be raised
        """
        # Presumably makes the mocked connect_unix_datagram() stall for 2s
        # (TODO confirm against tests._anyio.AsyncMock); the test expects this
        # to exceed connect()'s own timeout.
        connect_mock.delay = 2
        async with base.BaseClient() as client:
            with self.assertRaises(TimeoutError):
                await client.connect("foo")

    @anyio_mock.with_anyio()
    async def test_connect_timeout_2(self, send_mock, _):
        """
        Check a send/recv delay causes a TimeoutError to be raised
        """
        # Same idea as above, but the stall happens in the send_command()
        # mock that connect() is expected to use after opening the socket.
        send_mock.delay = 2
        async with base.BaseClient() as client:
            with self.assertRaises(TimeoutError):
                await client.connect("foo")
class SendMessageTests(unittest.TestCase):
    """
    Tests for the send_command() method
    """

    def setUp(self):
        # Each test gets a client whose socket is a mock; tests script the
        # replies via sock.receive.return_value or sock.receive.side_effect.
        self.client = client = base.BaseClient()
        client.sock = anyio_mock.AsyncMock(spec=anyio.abc.SocketStream)
        client.sock.send.return_value = None
        # Default reply for tests that do not script their own: test_simple
        # expects a bare "OK" acknowledgement to yield None.
        client.sock.receive.return_value = b"OK"
        assert isinstance(client.sock, anyio.abc.SocketStream)

    @anyio_mock.with_anyio()
    async def test_simple(self):
        """
        Check that a response is processed after a command
        """
        async with self.client as client:
            assert await client.send_command("SOME_COMMAND") is None

    @anyio_mock.with_anyio()
    async def test_simple_expect(self):
        """
        Check that an alternate expected response is processed
        """
        async with self.client as client:
            client.sock.receive.return_value = b"PONG"
            assert await client.send_command("PING", expect="PONG") is None

    @anyio_mock.with_anyio()
    async def test_simple_no_expect(self):
        """
        Check that an unexpected response raises an UnexpectedResponseError
        """
        async with self.client as client:
            client.sock.receive.return_value = b"DING"
            with self.assertRaises(errors.UnexpectedResponseError):
                await client.send_command("PING")
            with self.assertRaises(errors.UnexpectedResponseError):
                await client.send_command("PING", expect="PONG")

    @anyio_mock.with_anyio()
    async def test_simple_convert(self):
        """
        Check that a response is passed through a converter if given
        """
        async with self.client as client:
            client.sock.receive.return_value = b"FOO\nBAR\nBAZ\n"
            # The converter's output must be what send_command() returns.
            assert await client.send_command(
                "SOME_COMMAND", convert=lambda x: x.splitlines(),
            ) == ["FOO", "BAR", "BAZ"]

    @anyio_mock.with_anyio()
    async def test_simple_convert_over_expect(self):
        """
        Check that 'convert' overrides 'expect'
        """
        async with self.client as client:
            client.sock.receive.return_value = b"FOO\nBAR\nBAZ\n"
            # "FOO..." does not match expect="PONG"; the converted value must
            # still be returned rather than an UnexpectedResponseError.
            assert await client.send_command(
                "SOME_COMMAND", convert=lambda x: x.splitlines(), expect="PONG",
            ) == ["FOO", "BAR", "BAZ"]

    @anyio_mock.with_anyio()
    async def test_simple_fail(self):
        """
        Check that a response of 'FAIL' causes CommandFailed to be raised
        """
        async with self.client as client:
            client.sock.receive.return_value = b"FAIL"
            with self.assertRaises(errors.CommandFailed):
                await client.send_command("SOME_COMMAND")

    @anyio_mock.with_anyio()
    async def test_simple_bad_command(self):
        """
        Check that a response of 'UNKNOWN COMMAND' causes ValueError to be raised
        """
        async with self.client as client:
            client.sock.receive.return_value = b"UNKNOWN COMMAND"
            with self.assertRaises(ValueError):
                await client.send_command("SOME_COMMAND")

    @anyio_mock.with_anyio()
    async def test_interleaved(self):
        """
        Check that messages are processed alongside replies
        """
        async with self.client as client:
            client.sock.receive.side_effect = [
                b"<2>SOME-MESSAGE",
                b"<1>SOME-OTHER-MESSAGE with|args",
                b"OK",
            ]
            assert await client.send_command("SOME_COMMAND") is None

    @anyio_mock.with_anyio()
    async def test_unexpected(self):
        """
        Check that unexpected replies are logged cleanly
        """
        async with self.client as client:
            client.sock.receive.side_effect = [
                b"OK",  # Response to "ATTACH"
                b"UNEXPECTED1",
                b"UNEXPECTED2",
                # The awaited event must arrive (after the stray replies)
                # for event() to return before the reply list runs out.
                b"<2>CTRL-EVENT-EXAMPLE",
                b"OK",  # Response to "DETACH"
            ]
            # NOTE(review): nothing here asserts on the log output the
            # docstring mentions; consider assertLogs() once the logger name
            # and level used by base are confirmed.
            assert await client.event("CTRL-EVENT-EXAMPLE")

    @anyio_mock.with_anyio()
    async def test_unconnected(self):
        """
        Check that calling send_command() on an unconnected client raises RuntimeError
        """
        client = base.BaseClient()
        with self.assertRaises(RuntimeError):
            await client.send_command("SOME_COMMAND")

    @anyio_mock.with_anyio()
    async def test_multi_task(self):
        """
        Check that calling send_command() from multiple tasks works as expected
        """
        recv_responses = iter([
            (0.0, b"OK"),  # Response to ATTACH
            (0.5, b"OK"),  # Response to SOME_COMMAND1
            (0.2, b"<2>CTRL-FOO"),  # Event
            (0.1, b"REPLY2"),  # Response to SOME_COMMAND2
            (0.0, b"OK"),  # Response to DETACH
        ])

        async def recv():
            # Deliver each scripted reply after its delay so the replies
            # interleave with the tasks in a controlled order.
            delay, data = next(recv_responses)
            await anyio.sleep(delay)
            return data

        async with self.client as client, anyio.create_task_group() as task_group:
            client.sock.receive.side_effect = recv

            @task_group.start_soon
            async def wait_for_event():
                self.assertTupleEqual(
                    await client.event("CTRL-FOO"),
                    (base.EventPriority.INFO, "CTRL-FOO", None),
                )

            await anyio.sleep(0.1)  # Ensure send_command("ATTACH") has been sent
            task_group.start_soon(client.send_command, "SOME_COMMAND1")
            await anyio.sleep(0.1)  # Ensure send_command("SOME_COMMAND1") has been sent
            # At this point the response to SOME_COMMAND1 is still delayed
            await client.send_command("SOME_COMMAND2", expect="REPLY2")

    @anyio_mock.with_anyio()
    async def test_multi_task_decode_error(self):
        """
        Check that a decode error closes the socket and causes all waiting
        tasks to raise ClosedResourceError
        """
        recv_responses = [
            b"OK",  # Response to ATTACH
            b"\xa5\x8b",  # Undecodable input
            anyio.EndOfStream,
            anyio.EndOfStream,
        ]
        async with self.client as client, anyio.create_task_group() as task_group:
            client.sock.receive.side_effect = recv_responses

            @task_group.start_soon
            async def wait_for_event():
                with self.assertRaises(anyio.ClosedResourceError):
                    await client.event("CTRL-FOO")

            await anyio.sleep(0.1)  # Ensure send_command("ATTACH") has been sent
            with self.assertRaises(anyio.ClosedResourceError):
                await client.send_command("SOME_COMMAND", expect="REPLY")
class EventTests(unittest.TestCase):
    """
    Tests for the event() method
    """

    def setUp(self):
        # Each test gets a client whose socket is a mock; every test below
        # scripts the replies through sock.receive.side_effect.
        self.client = client = base.BaseClient()
        client.sock = anyio_mock.AsyncMock(spec=anyio.abc.SocketStream)
        client.sock.send.return_value = None
        assert isinstance(client.sock, anyio.abc.SocketStream)

    @anyio_mock.with_anyio()
    async def test_simple(self):
        """
        Check that an awaited message is returned when it arrives
        """
        with anyio.fail_after(2):
            async with self.client as client:
                client.sock.receive.side_effect = [
                    b"OK",  # Respond to ATTACH
                    b"<2>CTRL-EVENT-EXAMPLE",
                    b"OK",  # Respond to DETACH
                ]
                prio, evt, args = await client.event("CTRL-EVENT-EXAMPLE")
                assert prio == 2
                assert evt == "CTRL-EVENT-EXAMPLE"
                assert args is None

    @anyio_mock.with_anyio()
    async def test_multiple(self):
        """
        Check that an awaited message is returned when it arrives between others
        """
        with anyio.fail_after(2):
            async with self.client as client:
                client.sock.receive.side_effect = [
                    b"OK",  # Respond to ATTACH
                    b"<1>OTHER-MESSAGE",
                    b"<2>CTRL-EVENT-OTHER",
                    b"<4>CTRL-EVENT-EXAMPLE",
                    b"OK",  # Respond to DETACH
                    b"<3>OTHER-MESSAGE",
                ]
                prio, evt, args = await client.event("CTRL-EVENT-EXAMPLE")
                assert prio == 4
                assert evt == "CTRL-EVENT-EXAMPLE"
                assert args is None

    @anyio_mock.with_anyio()
    async def test_wait_multiple(self):
        """
        Check that the first of several awaited events is returned
        """
        with anyio.fail_after(2):
            async with self.client as client:
                client.sock.receive.side_effect = [
                    b"OK",  # Respond to ATTACH
                    b"<1>OTHER-MESSAGE",
                    b"<2>CTRL-EVENT-OTHER",
                    b"<4>CTRL-EVENT-EXAMPLE3",
                    b"<4>CTRL-EVENT-EXAMPLE1",
                    b"OK",  # Respond to DETACH
                    b"<3>CTRL-EVENT-OTHER",
                ]
                prio, evt, args = await client.event(
                    "CTRL-EVENT-EXAMPLE1", "CTRL-EVENT-EXAMPLE2", "CTRL-EVENT-EXAMPLE3",
                )
                # EXAMPLE3 arrives first, so it wins even though EXAMPLE1 is
                # listed first in the call.
                assert prio == 4
                assert evt == "CTRL-EVENT-EXAMPLE3"
                assert args is None

    @anyio_mock.with_anyio()
    async def test_interleaved(self):
        """
        Check that messages are processed as well as replies
        """
        with anyio.fail_after(2):
            async with self.client as client:
                client.sock.receive.side_effect = [
                    b"<1>OTHER-MESSAGE",
                    b"OK",  # Respond to SOME_COMMAND
                    b"OK",  # Respond to ATTACH
                    b"<2>CTRL-EVENT-OTHER",
                    b"<4>CTRL-EVENT-EXAMPLE",
                    b"<3>CTRL-EVENT-OTHER",
                    b"OK",  # Respond to DETACH
                    b"FOO",
                ]
                assert await client.send_command("SOME_COMMAND") is None
                prio, evt, args = await client.event("CTRL-EVENT-EXAMPLE")
                assert prio == 4
                assert evt == "CTRL-EVENT-EXAMPLE"
                assert args is None
                assert await client.send_command("SOME_COMMAND", expect="FOO") is None

    @anyio_mock.with_anyio()
    async def test_unconnected(self):
        """
        Check that calling event() on an unconnected client raises RuntimeError
        """
        client = base.BaseClient()
        with self.assertRaises(RuntimeError):
            await client.event("some", "events")