From ba05d38602f7fad2987df9fe407ee68ea03d69e6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 1 Nov 2024 11:46:10 -0500 Subject: [PATCH] Implement zerocopy writes (#990) --- aioesphomeapi/_frame_helper/base.pxd | 2 +- aioesphomeapi/_frame_helper/base.py | 21 ++++++++------ aioesphomeapi/_frame_helper/noise.py | 4 +-- aioesphomeapi/_frame_helper/plain_text.py | 5 ++-- tests/common.py | 18 +++++++----- tests/test__frame_helper.py | 23 +++++++-------- tests/test_client.py | 25 ++++++++++++----- tests/test_connection.py | 34 ++++++++++++----------- 8 files changed, 78 insertions(+), 54 deletions(-) diff --git a/aioesphomeapi/_frame_helper/base.pxd b/aioesphomeapi/_frame_helper/base.pxd index 187be7af..72638b0a 100644 --- a/aioesphomeapi/_frame_helper/base.pxd +++ b/aioesphomeapi/_frame_helper/base.pxd @@ -11,7 +11,7 @@ cdef class APIFrameHelper: cdef object _loop cdef APIConnection _connection cdef object _transport - cdef public object _writer + cdef public object _writelines cdef public object ready_future cdef bytes _buffer cdef unsigned int _buffer_len diff --git a/aioesphomeapi/_frame_helper/base.py b/aioesphomeapi/_frame_helper/base.py index 940c61a3..f73365ca 100644 --- a/aioesphomeapi/_frame_helper/base.py +++ b/aioesphomeapi/_frame_helper/base.py @@ -2,6 +2,7 @@ from abc import abstractmethod import asyncio +from collections.abc import Iterable import logging from typing import TYPE_CHECKING, Callable, cast @@ -31,7 +32,7 @@ class APIFrameHelper: "_loop", "_connection", "_transport", - "_writer", + "_writelines", "ready_future", "_buffer", "_buffer_len", @@ -51,7 +52,9 @@ def __init__( self._loop = loop self._connection = connection self._transport: asyncio.Transport | None = None - self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None + self._writelines: ( + None | (Callable[[Iterable[bytes | bytearray | memoryview[int]]], None]) + ) = None self.ready_future = self._loop.create_future() self._buffer: bytes | None = None self._buffer_len = 0 @@ -146,7 +149,7 @@ def write_packets( def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle a new connection.""" self._transport = cast(asyncio.Transport, transport) - self._writer = self._transport.write + self._writelines = self._transport.writelines def _handle_error_and_close(self, exc: Exception) -> None: self._handle_error(exc) @@ -172,7 +175,7 @@ def close(self) -> None: if self._transport: self._transport.close() self._transport = None - self._writer = None + self._writelines = None def pause_writing(self) -> None: """Stub.""" @@ -180,12 +183,14 @@ def pause_writing(self) -> None: def resume_writing(self) -> None: """Stub.""" - def _write_bytes(self, data: _bytes, debug_enabled: bool) -> None: + def _write_bytes(self, data: Iterable[_bytes], debug_enabled: bool) -> None: """Write bytes to the socket.""" if debug_enabled: - _LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex()) + _LOGGER.debug( + "%s: Sending frame: [%s]", self._log_name, b"".join(data).hex() + ) if TYPE_CHECKING: - assert self._writer is not None, "Writer is not set" + assert self._writelines is not None, "Writer is not set" - self._writer(data) + self._writelines(data) diff --git a/aioesphomeapi/_frame_helper/noise.py b/aioesphomeapi/_frame_helper/noise.py index ee0623ea..ffcfd2ff 100644 --- a/aioesphomeapi/_frame_helper/noise.py +++ b/aioesphomeapi/_frame_helper/noise.py @@ -218,7 +218,7 @@ def _send_hello_handshake(self) -> None: frame_len = len(handshake_frame) + 1 header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) self._write_bytes( - b"".join((NOISE_HELLO, header, b"\x00", handshake_frame)), + (NOISE_HELLO, header, b"\x00", handshake_frame), _LOGGER.isEnabledFor(logging.DEBUG), ) @@ -346,7 +346,7 @@ def write_packets( out.append(header) out.append(frame) - self._write_bytes(b"".join(out), debug_enabled) + self._write_bytes(out, debug_enabled) def _handle_frame(self, frame: bytes) -> None: """Handle an incoming frame.""" diff --git a/aioesphomeapi/_frame_helper/plain_text.py b/aioesphomeapi/_frame_helper/plain_text.py index 8d8cc9ed..d8a593b8 100644 --- a/aioesphomeapi/_frame_helper/plain_text.py +++ b/aioesphomeapi/_frame_helper/plain_text.py @@ -57,9 +57,10 @@ def write_packets( out.append(b"\0") out.append(varuint_to_bytes(len(data))) out.append(varuint_to_bytes(type_)) - out.append(data) + if data: + out.append(data) - self._write_bytes(b"".join(out), debug_enabled) + self._write_bytes(out, debug_enabled) def data_received(self, data: bytes | bytearray | memoryview) -> None: self._add_to_buffer(data) diff --git a/tests/common.py b/tests/common.py index 95af5dbc..5408bd2a 100644 --- a/tests/common.py +++ b/tests/common.py @@ -65,15 +65,19 @@ class Estr(str): """A subclassed string.""" -def generate_plaintext_packet(msg: message.Message) -> bytes: +def generate_split_plaintext_packet(msg: message.Message) -> list[bytes]: type_ = PROTO_TO_MESSAGE_TYPE[msg.__class__] bytes_ = msg.SerializeToString() - return ( - b"\0" - + _cached_varuint_to_bytes(len(bytes_)) - + _cached_varuint_to_bytes(type_) - + bytes_ - ) + return [ + b"\0", + _cached_varuint_to_bytes(len(bytes_)), + _cached_varuint_to_bytes(type_), + bytes_, + ] + + +def generate_plaintext_packet(msg: message.Message) -> bytes: + return b"".join(generate_split_plaintext_packet(msg)) def as_utc(dattim: datetime) -> datetime: diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index b91c642c..382fd7aa 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -2,6 +2,7 @@ import asyncio import base64 +from collections.abc import Iterable import sys from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -132,7 +133,7 @@ def __init__(self, *args: Any, writer: Any | None = None, **kwargs: Any) -> None """Swallow args.""" super().__init__(*args, **kwargs) transport = MagicMock() - transport.write = writer or MagicMock() + transport.writelines = writer or MagicMock() self.__transport = transport self.connection_made(transport) @@ -147,7 +148,7 @@ def mock_write_frame(self, frame: bytes) -> None: frame_len = len(frame) header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) try: - self._writer(header + frame) + self._writelines([header, frame]) except (RuntimeError, ConnectionResetError, OSError) as err: raise SocketClosedAPIError( f"{self._log_name}: Error while writing data: {err}" @@ -437,8 +438,8 @@ async def test_noise_frame_helper_handshake_failure(): psk_bytes = base64.b64decode(noise_psk) writes = [] - def _writer(data: bytes): - writes.append(data) + def _writelines(data: Iterable[bytes]): + writes.append(b"".join(data)) connection, _ = _make_mock_connection() @@ -448,7 +449,7 @@ def _writer(data: bytes): expected_name="servicetest", client_info="my client", log_name="test", - writer=_writer, + writer=_writelines, ) proto = _mock_responder_proto(psk_bytes) @@ -486,8 +487,8 @@ async def test_noise_frame_helper_handshake_success_with_single_packet(): psk_bytes = base64.b64decode(noise_psk) writes = [] - def _writer(data: bytes): - writes.append(data) + def _writelines(data: Iterable[bytes]): + writes.append(b"".join(data)) connection, packets = _make_mock_connection() @@ -497,7 +498,7 @@ def _writer(data: bytes): expected_name="servicetest", client_info="my client", log_name="test", - writer=_writer, + writer=_writelines, ) proto = _mock_responder_proto(psk_bytes) @@ -548,8 +549,8 @@ async def test_noise_frame_helper_bad_encryption( psk_bytes = base64.b64decode(noise_psk) writes = [] - def _writer(data: bytes): - writes.append(data) + def _writelines(data: Iterable[bytes]): + writes.append(b"".join(data)) connection, packets = _make_mock_connection() @@ -559,7 +560,7 @@ def _writer(data: bytes): expected_name="servicetest", client_info="my client", log_name="test", - writer=_writer, + writer=_writelines, ) proto = _mock_responder_proto(psk_bytes) diff --git a/tests/test_client.py b/tests/test_client.py index cd5e2dfb..bba93564 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -126,6 +126,7 @@ from .common import ( Estr, generate_plaintext_packet, + generate_split_plaintext_packet, get_mock_zeroconf, mock_data_received, ) @@ -1439,7 +1440,12 @@ async def test_bluetooth_gatt_write_without_response( ) await asyncio.sleep(0) await write_task - assert transport.mock_calls[0][1][0] == b'\x00\x0cK\x08\xd2\t\x10\xd2\t"\x041234' + assert transport.mock_calls[0][1][0] == [ + b"\x00", + b"\x0c", + b"K", + b'\x08\xd2\t\x10\xd2\t"\x041234', + ] with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"): await client.bluetooth_gatt_write(1234, 1234, b"1234", True, timeout=0) @@ -1484,7 +1490,12 @@ async def test_bluetooth_gatt_write_descriptor_without_response( ) await asyncio.sleep(0) await write_task - assert transport.mock_calls[0][1][0] == b"\x00\x0cM\x08\xd2\t\x10\xd2\t\x1a\x041234" + assert transport.mock_calls[0][1][0] == [ + b"\x00", + b"\x0c", + b"M", + b"\x08\xd2\t\x10\xd2\t\x1a\x041234", + ] with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"): await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0) @@ -2042,8 +2053,8 @@ def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None cancel = await connect_task assert states == [(True, 23, 0)] - transport.write.assert_called_once_with( - generate_plaintext_packet( + transport.writelines.assert_called_once_with( + generate_split_plaintext_packet( BluetoothDeviceRequest( address=1234, request_type=method, @@ -2133,13 +2144,13 @@ def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None ) await asyncio.sleep(0) # The connect request should be written - assert len(transport.write.mock_calls) == 1 + assert len(transport.writelines.mock_calls) == 1 await asyncio.sleep(0) await asyncio.sleep(0) await asyncio.sleep(0) # Now that we timed out, the disconnect # request should be written - assert len(transport.write.mock_calls) == 2 + assert len(transport.writelines.mock_calls) == 2 response: message.Message = BluetoothDeviceConnectionResponse( address=1234, connected=False, mtu=23, error=8 ) @@ -2177,7 +2188,7 @@ def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None ) await asyncio.sleep(0) # The connect request should be written - assert len(transport.write.mock_calls) == 1 + assert len(transport.writelines.mock_calls) == 1 connect_task.cancel() with pytest.raises(asyncio.CancelledError): await connect_task diff --git a/tests/test_connection.py b/tests/test_connection.py index 260b3f48..bd9778ab 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -115,7 +115,7 @@ async def test_timeout_sending_message( with patch("aioesphomeapi.connection.DISCONNECT_RESPONSE_TIMEOUT", 0.0): await conn.disconnect() - transport.write.assert_called_with(b"\x00\x00\x05") + transport.writelines.assert_called_with([b"\x00", b"\x00", b"\x05"]) assert "disconnect request failed" in caplog.text assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text @@ -152,7 +152,7 @@ async def test_disconnect_when_not_fully_connected( ): await connect_task - transport.write.assert_called_with(b"\x00\x00\x05") + transport.writelines.assert_called_with([b"\x00", b"\x00", b"\x05"]) assert "disconnect request failed" in caplog.text assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text @@ -506,7 +506,7 @@ def _create_failing_mock_transport_protocol( ) -> tuple[asyncio.Transport, APIPlaintextFrameHelperHandshakeException]: protocol: APIPlaintextFrameHelperHandshakeException = create_func() protocol._transport = cast(asyncio.Transport, transport) - protocol._writer = transport.write + protocol._writelines = transport.writelines protocol.ready_future.set_exception(exception) connected.set() return transport, protocol @@ -549,7 +549,9 @@ async def _do_finish_connect(self, *args, **kwargs): connect_task = asyncio.create_task(connect(conn, login=False)) await connected.wait() - with (pytest.raises(raised_exception),): + with ( + pytest.raises(raised_exception), + ): await asyncio.sleep(0) await connect_task @@ -646,7 +648,7 @@ async def test_force_disconnect_fails( await connect_task assert conn.is_connected - with patch.object(protocol, "_writer", side_effect=OSError): + with patch.object(protocol, "_writelines", side_effect=OSError): conn.force_disconnect() assert "Failed to send (forced) disconnect request" in caplog.text await asyncio.sleep(0) @@ -822,7 +824,7 @@ async def _on_stop(_expected_disconnect: bool) -> None: await connect_task assert client._connection.is_connected - with patch.object(protocol, "_writer", side_effect=OSError): + with patch.object(protocol, "_writelines", side_effect=OSError): disconnect_request = DisconnectRequest() mock_data_received(protocol, generate_plaintext_packet(disconnect_request)) @@ -893,7 +895,7 @@ async def test_ping_disconnects_after_no_responses( await connect_task - ping_request_bytes = b"\x00\x00\x07" + ping_request_bytes = [b"\x00", b"\x00", b"\x07"] assert conn.is_connected transport.reset_mock() @@ -904,9 +906,9 @@ async def test_ping_disconnects_after_no_responses( async_fire_time_changed( start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count) ) - assert transport.write.call_count == count + assert transport.writelines.call_count == count expected_calls.append(call(ping_request_bytes)) - assert transport.write.mock_calls == expected_calls + assert transport.writelines.mock_calls == expected_calls assert conn.is_connected is True @@ -915,7 +917,7 @@ async def test_ping_disconnects_after_no_responses( start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1)) ) - assert transport.write.call_count == max_pings_to_disconnect_after + 1 + assert transport.writelines.call_count == max_pings_to_disconnect_after + 1 assert conn.is_connected is False @@ -932,7 +934,7 @@ async def test_ping_does_not_disconnect_if_we_get_responses( send_plaintext_connect_response(protocol, False) await connect_task - ping_request_bytes = b"\x00\x00\x07" + ping_request_bytes = [b"\x00", b"\x00", b"\x07"] assert conn.is_connected transport.reset_mock() @@ -945,8 +947,8 @@ async def test_ping_does_not_disconnect_if_we_get_responses( send_ping_response(protocol) # We should only send 1 ping request if we are getting responses - assert transport.write.call_count == 1 - assert transport.write.mock_calls == [call(ping_request_bytes)] + assert transport.writelines.call_count == 1 + assert transport.writelines.mock_calls == [call(ping_request_bytes)] # We should disconnect if we are getting ping responses assert conn.is_connected is True @@ -976,9 +978,9 @@ async def test_respond_to_ping_request( transport.reset_mock() send_ping_request(protocol) # We should respond to ping requests - ping_response_bytes = b"\x00\x00\x08" - assert transport.write.call_count == 1 - assert transport.write.mock_calls == [call(ping_response_bytes)] + ping_response_bytes = [b"\x00", b"\x00", b"\x08"] + assert transport.writelines.call_count == 1 + assert transport.writelines.mock_calls == [call(ping_response_bytes)] @pytest.mark.asyncio