From 4e7a466f899d4a2dc92362a0f86d60dee5b22758 Mon Sep 17 00:00:00 2001 From: John Belmonte Date: Mon, 26 Oct 2020 11:11:21 +0900 Subject: [PATCH] open_memory_channel(): return a named tuple partially addresses #719 --- docs/source/reference-core.rst | 6 ++++++ newsfragments/1771.feature.rst | 5 +++++ trio/_channel.py | 17 ++++++++++++----- trio/tests/test_channel.py | 5 +++++ trio/tests/test_highlevel_serve_listeners.py | 4 ++-- 5 files changed, 30 insertions(+), 7 deletions(-) create mode 100644 newsfragments/1771.feature.rst diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index 664a8d96c2..b67bf52e78 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -1124,6 +1124,12 @@ inside a single process, and for that you can use .. autofunction:: open_memory_channel(max_buffer_size) +Assigning the send and receive channels to separate variables usually +produces the most readable code. However, in situations where the pair +is preserved-- such as a collection of memory channels-- prefer named tuple +access (``pair.send_channel``, ``pair.receive_channel``) over indexed access +(``pair[0]``, ``pair[1]``). + .. note:: If you've used the :mod:`threading` or :mod:`asyncio` modules, you may be familiar with :class:`queue.Queue` or :class:`asyncio.Queue`. In Trio, :func:`open_memory_channel` is diff --git a/newsfragments/1771.feature.rst b/newsfragments/1771.feature.rst new file mode 100644 index 0000000000..a2e861b7d6 --- /dev/null +++ b/newsfragments/1771.feature.rst @@ -0,0 +1,5 @@ +open_memory_channel() now returns a named tuple with attributes ``send_channel`` +and ```receive_channel`. This can be used to avoid indexed access of the +channel halves in some scenarios such as a collection of channels. (Note: when +dealing with a single memory channel, assigning the send and receive halves +to separate variables via destructuring is still considered more readable.) diff --git a/trio/_channel.py b/trio/_channel.py index dac7935c0c..417e530024 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -1,16 +1,24 @@ from collections import deque, OrderedDict from math import inf +from typing import NamedTuple import attr from outcome import Error, Value -from .abc import SendChannel, ReceiveChannel, Channel +from .abc import SendChannel, ReceiveChannel from ._util import generic_function, NoPublicConstructor import trio from ._core import enable_ki_protection +class MemoryChannelPair(NamedTuple): + """Named tuple of send/receive memory channels""" + + send_channel: "MemorySendChannel" + receive_channel: "MemoryReceiveChannel" + + @generic_function def open_memory_channel(max_buffer_size): """Open a channel for passing objects between tasks within a process. @@ -40,9 +48,8 @@ def open_memory_channel(max_buffer_size): see :ref:`channel-buffering` for more details. If in doubt, use 0. Returns: - A pair ``(send_channel, receive_channel)``. If you have - trouble remembering which order these go in, remember: data - flows from left → right. + A named tuple ``(send_channel, receive_channel)``. The tuple ordering is + intended to match the image of data flowing from left → right. In addition to the standard channel methods, all memory channel objects provide a ``statistics()`` method, which returns an object with the @@ -69,7 +76,7 @@ def open_memory_channel(max_buffer_size): if max_buffer_size < 0: raise ValueError("max_buffer_size must be >= 0") state = MemoryChannelState(max_buffer_size) - return ( + return MemoryChannelPair( MemorySendChannel._create(state), MemoryReceiveChannel._create(state), ) diff --git a/trio/tests/test_channel.py b/trio/tests/test_channel.py index b43466dd7d..83fd746bdc 100644 --- a/trio/tests/test_channel.py +++ b/trio/tests/test_channel.py @@ -350,3 +350,8 @@ async def do_send(s, v): assert await r.receive() == 1 with pytest.raises(trio.WouldBlock): r.receive_nowait() + + +def test_named_tuple(): + pair = open_memory_channel(0) + assert pair.send_channel, pair.receive_channel == pair diff --git a/trio/tests/test_highlevel_serve_listeners.py b/trio/tests/test_highlevel_serve_listeners.py index b028092eb9..7925a16ff4 100644 --- a/trio/tests/test_highlevel_serve_listeners.py +++ b/trio/tests/test_highlevel_serve_listeners.py @@ -19,7 +19,7 @@ class MemoryListener(trio.abc.Listener): async def connect(self): assert not self.closed client, server = memory_stream_pair() - await self.queued_streams[0].send(server) + await self.queued_streams.send_channel.send(server) return client async def accept(self): @@ -27,7 +27,7 @@ async def accept(self): assert not self.closed if self.accept_hook is not None: await self.accept_hook() - stream = await self.queued_streams[1].receive() + stream = await self.queued_streams.receive_channel.receive() self.accepted_streams.append(stream) return stream