diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index e63c2f8f5..fd8300183 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -136,8 +136,8 @@ measured during the last exchange of Ping and Pong frames:: Alternatively, you can measure the latency at any time by calling :attr:`~asyncio.connection.Connection.ping` and awaiting its result:: - pong_waiter = await websocket.ping() - latency = await pong_waiter + pong_received = await websocket.ping() + latency = await pong_received Latency between a client and a server may increase for two reasons: diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 61c300d63..e7af71fc5 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -101,10 +101,10 @@ def __init__( self.close_deadline: float | None = None # Protect sending fragmented messages. - self.fragmented_send_waiter: asyncio.Future[None] | None = None + self.send_in_progress: asyncio.Future[None] | None = None # Mapping of ping IDs to pong waiters, in chronological order. - self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} + self.pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} self.latency: float = 0 """ @@ -468,8 +468,8 @@ async def send( """ # While sending a fragmented message, prevent sending other messages # until all fragments are sent. - while self.fragmented_send_waiter is not None: - await asyncio.shield(self.fragmented_send_waiter) + while self.send_in_progress is not None: + await asyncio.shield(self.send_in_progress) # Unfragmented message -- this case must be handled first because # strings and bytes-like objects are iterable. @@ -502,8 +502,8 @@ async def send( except StopIteration: return - assert self.fragmented_send_waiter is None - self.fragmented_send_waiter = self.loop.create_future() + assert self.send_in_progress is None + self.send_in_progress = self.loop.create_future() try: # First fragment. if isinstance(chunk, str): @@ -549,8 +549,8 @@ async def send( raise finally: - self.fragmented_send_waiter.set_result(None) - self.fragmented_send_waiter = None + self.send_in_progress.set_result(None) + self.send_in_progress = None # Fragmented message -- async iterator. @@ -561,8 +561,8 @@ async def send( except StopAsyncIteration: return - assert self.fragmented_send_waiter is None - self.fragmented_send_waiter = self.loop.create_future() + assert self.send_in_progress is None + self.send_in_progress = self.loop.create_future() try: # First fragment. if isinstance(chunk, str): @@ -610,8 +610,8 @@ async def send( raise finally: - self.fragmented_send_waiter.set_result(None) - self.fragmented_send_waiter = None + self.send_in_progress.set_result(None) + self.send_in_progress = None else: raise TypeError("data must be str, bytes, iterable, or async iterable") @@ -635,7 +635,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: # The context manager takes care of waiting for the TCP connection # to terminate after calling a method that sends a close frame. async with self.send_context(): - if self.fragmented_send_waiter is not None: + if self.send_in_progress is not None: self.protocol.fail( CloseCode.INTERNAL_ERROR, "close during fragmented message", @@ -677,9 +677,9 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: :: - pong_waiter = await ws.ping() + pong_received = await ws.ping() # only if you want to wait for the corresponding pong - latency = await pong_waiter + latency = await pong_received Raises: ConnectionClosed: When the connection is closed. @@ -696,19 +696,19 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: async with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.pong_waiters: + if data in self.pending_pings: raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.pong_waiters: + while data is None or data in self.pending_pings: data = struct.pack("!I", random.getrandbits(32)) - pong_waiter = self.loop.create_future() + pong_received = self.loop.create_future() # The event loop's default clock is time.monotonic(). Its resolution # is a bit low on Windows (~16ms). This is improved in Python 3.13. - self.pong_waiters[data] = (pong_waiter, self.loop.time()) + self.pending_pings[data] = (pong_received, self.loop.time()) self.protocol.send_ping(data) - return pong_waiter + return pong_received async def pong(self, data: Data = b"") -> None: """ @@ -757,7 +757,7 @@ def acknowledge_pings(self, data: bytes) -> None: """ # Ignore unsolicited pong. - if data not in self.pong_waiters: + if data not in self.pending_pings: return pong_timestamp = self.loop.time() @@ -766,20 +766,20 @@ def acknowledge_pings(self, data: bytes) -> None: # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): + for ping_id, (pong_received, ping_timestamp) in self.pending_pings.items(): ping_ids.append(ping_id) latency = pong_timestamp - ping_timestamp - if not pong_waiter.done(): - pong_waiter.set_result(latency) + if not pong_received.done(): + pong_received.set_result(latency) if ping_id == data: self.latency = latency break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.pong_waiters. + # Remove acknowledged pings from self.pending_pings. for ping_id in ping_ids: - del self.pong_waiters[ping_id] + del self.pending_pings[ping_id] def abort_pings(self) -> None: """ @@ -791,16 +791,16 @@ def abort_pings(self) -> None: assert self.protocol.state is CLOSED exc = self.protocol.close_exc - for pong_waiter, _ping_timestamp in self.pong_waiters.values(): - if not pong_waiter.done(): - pong_waiter.set_exception(exc) + for pong_received, _ping_timestamp in self.pending_pings.values(): + if not pong_received.done(): + pong_received.set_exception(exc) # If the exception is never retrieved, it will be logged when ping # is garbage-collected. This is confusing for users. # Given that ping is done (with an exception), canceling it does # nothing, but it prevents logging the exception. - pong_waiter.cancel() + pong_received.cancel() - self.pong_waiters.clear() + self.pending_pings.clear() async def keepalive(self) -> None: """ @@ -821,7 +821,7 @@ async def keepalive(self) -> None: # connection to be closed before raising ConnectionClosed. # However, connection_lost() cancels keepalive_task before # it gets a chance to resume excuting. - pong_waiter = await self.ping() + pong_received = await self.ping() if self.debug: self.logger.debug("% sent keepalive ping") @@ -830,9 +830,9 @@ async def keepalive(self) -> None: async with asyncio_timeout(self.ping_timeout): # connection_lost cancels keepalive immediately # after setting a ConnectionClosed exception on - # pong_waiter. A CancelledError is raised here, + # pong_received. A CancelledError is raised here, # not a ConnectionClosed exception. - latency = await pong_waiter + latency = await pong_received self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: @@ -1201,7 +1201,7 @@ def broadcast( if connection.protocol.state is not OPEN: continue - if connection.fragmented_send_waiter is not None: + if connection.send_in_progress is not None: if raise_exceptions: exception = ConcurrencyError("sending a fragmented message") exceptions.append(exception) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 1fd41811c..1fe33f709 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -81,8 +81,7 @@ class Assembler: """ - # coverage reports incorrectly: "line NN didn't jump to the function exit" - def __init__( # pragma: no cover + def __init__( self, high: int | None = None, low: int | None = None, diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index bedbf4def..04f65f3f6 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -104,7 +104,7 @@ def __init__( self.send_in_progress = False # Mapping of ping IDs to pong waiters, in chronological order. - self.pong_waiters: dict[bytes, tuple[threading.Event, float, bool]] = {} + self.pending_pings: dict[bytes, tuple[threading.Event, float, bool]] = {} self.latency: float = 0 """ @@ -629,8 +629,9 @@ def ping( :: - pong_event = ws.ping() - pong_event.wait() # only if you want to wait for the pong + pong_received = ws.ping() + # only if you want to wait for the corresponding pong + pong_received.wait() Raises: ConnectionClosed: When the connection is closed. @@ -647,17 +648,17 @@ def ping( with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.pong_waiters: + if data in self.pending_pings: raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.pong_waiters: + while data is None or data in self.pending_pings: data = struct.pack("!I", random.getrandbits(32)) - pong_waiter = threading.Event() - self.pong_waiters[data] = (pong_waiter, time.monotonic(), ack_on_close) + pong_received = threading.Event() + self.pending_pings[data] = (pong_received, time.monotonic(), ack_on_close) self.protocol.send_ping(data) - return pong_waiter + return pong_received def pong(self, data: Data = b"") -> None: """ @@ -707,7 +708,7 @@ def acknowledge_pings(self, data: bytes) -> None: """ with self.protocol_mutex: # Ignore unsolicited pong. - if data not in self.pong_waiters: + if data not in self.pending_pings: return pong_timestamp = time.monotonic() @@ -717,21 +718,21 @@ def acknowledge_pings(self, data: bytes) -> None: ping_id = None ping_ids = [] for ping_id, ( - pong_waiter, + pong_received, ping_timestamp, _ack_on_close, - ) in self.pong_waiters.items(): + ) in self.pending_pings.items(): ping_ids.append(ping_id) - pong_waiter.set() + pong_received.set() if ping_id == data: self.latency = pong_timestamp - ping_timestamp break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.pong_waiters. + # Remove acknowledged pings from self.pending_pings. for ping_id in ping_ids: - del self.pong_waiters[ping_id] + del self.pending_pings[ping_id] def acknowledge_pending_pings(self) -> None: """ @@ -740,11 +741,11 @@ def acknowledge_pending_pings(self) -> None: """ assert self.protocol.state is CLOSED - for pong_waiter, _ping_timestamp, ack_on_close in self.pong_waiters.values(): + for pong_received, _ping_timestamp, ack_on_close in self.pending_pings.values(): if ack_on_close: - pong_waiter.set() + pong_received.set() - self.pong_waiters.clear() + self.pending_pings.clear() def keepalive(self) -> None: """ @@ -762,15 +763,14 @@ def keepalive(self) -> None: break try: - pong_waiter = self.ping(ack_on_close=True) + pong_received = self.ping(ack_on_close=True) except ConnectionClosed: break if self.debug: self.logger.debug("% sent keepalive ping") if self.ping_timeout is not None: - # - if pong_waiter.wait(self.ping_timeout): + if pong_received.wait(self.ping_timeout): if self.debug: self.logger.debug("% received keepalive pong") else: @@ -804,7 +804,7 @@ def recv_events(self) -> None: Run this method in a thread as long as the connection is alive. - ``recv_events()`` exits immediately when the ``self.socket`` is closed. + ``recv_events()`` exits immediately when ``self.socket`` is closed. """ try: @@ -979,6 +979,7 @@ def send_context( # Minor layering violation: we assume that the connection # will be closing soon if it isn't in the expected state. wait_for_close = True + # TODO: calculate close deadline if not set? raise_close_exc = True # To avoid a deadlock, release the connection lock by exiting the diff --git a/src/websockets/trio/__init__.py b/src/websockets/trio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/websockets/trio/connection.py b/src/websockets/trio/connection.py new file mode 100644 index 000000000..2a77749b5 --- /dev/null +++ b/src/websockets/trio/connection.py @@ -0,0 +1,1114 @@ +from __future__ import annotations + +import contextlib +import logging +import random +import struct +import uuid +from collections.abc import AsyncIterable, AsyncIterator, Iterable, Mapping +from types import TracebackType +from typing import Any, Literal, overload + +import trio + +from ..asyncio.compatibility import ( + TimeoutError, + aiter, + anext, +) +from ..exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedOK, + ProtocolError, +) +from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode +from ..http11 import Request, Response +from ..protocol import CLOSED, OPEN, Event, Protocol, State +from ..typing import Data, LoggerLike, Subprotocol +from .messages import Assembler + + +__all__ = ["Connection"] + + +class Connection: + """ + :mod:`trio` implementation of a WebSocket connection. + + :class:`Connection` provides APIs shared between WebSocket servers and + clients. + + You shouldn't use it directly. Instead, use + :class:`~websockets.trio.client.ClientConnection` or + :class:`~websockets.trio.server.ServerConnection`. + + """ + + def __init__( + self, + nursery: trio.Nursery, + stream: trio.abc.Stream, + protocol: Protocol, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + ) -> None: + self.nursery = nursery + self.stream = stream + self.protocol = protocol + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.close_timeout = close_timeout + self.max_queue: tuple[int | None, int | None] + if isinstance(max_queue, int) or max_queue is None: + self.max_queue = (max_queue, None) + else: + self.max_queue = max_queue + + # Inject reference to this instance in the protocol's logger. + self.protocol.logger = logging.LoggerAdapter( + self.protocol.logger, + {"websocket": self}, + ) + + # Copy attributes from the protocol for convenience. + self.id: uuid.UUID = self.protocol.id + """Unique identifier of the connection. Useful in logs.""" + self.logger: LoggerLike = self.protocol.logger + """Logger for this connection.""" + self.debug = self.protocol.debug + + # HTTP handshake request and response. + self.request: Request | None = None + """Opening handshake request.""" + self.response: Response | None = None + """Opening handshake response.""" + + # Lock stopping reads when the assembler buffer is full. + self.recv_flow_control = trio.Lock() + + # Assembler turning frames into messages and serializing reads. + self.recv_messages = Assembler( + *self.max_queue, + pause=self.recv_flow_control.acquire_nowait, + resume=self.recv_flow_control.release, + ) + + # Deadline for the closing handshake. + self.close_deadline: float | None = None + + # Protect sending fragmented messages. + self.send_in_progress: trio.Event | None = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pending_pings: dict[bytes, tuple[trio.Event, float, bool]] = {} + + self.latency: float = 0 + """ + Latency of the connection, in seconds. + + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. + """ + + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Start recv_events only after all attributes are initialized. + self.nursery.start_soon(self.recv_events) + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.stream_closed: trio.Event = trio.Event() + + # Public attributes + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getsockname`. + + """ + if isinstance(self.stream, trio.SSLStream): # pragma: no cover + stream = self.stream.transport_stream + else: + stream = self.stream + if isinstance(stream, trio.SocketStream): + return stream.socket.getsockname() + else: # pragma: no cover + raise NotImplementedError + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getpeername`. + + """ + if isinstance(self.stream, trio.SSLStream): # pragma: no cover + stream = self.stream.transport_stream + else: + stream = self.stream + if isinstance(stream, trio.SocketStream): + return stream.socket.getpeername() + else: # pragma: no cover + raise NotImplementedError + + @property + def state(self) -> State: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should call :meth:`~recv` or + :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed` + exceptions. + + """ + return self.protocol.state + + @property + def subprotocol(self) -> Subprotocol | None: + """ + Subprotocol negotiated during the opening handshake. + + :obj:`None` if no subprotocol was negotiated. + + """ + return self.protocol.subprotocol + + @property + def close_code(self) -> int | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_code + + @property + def close_reason(self) -> str | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_reason + + # Public methods + + async def __aenter__(self) -> Connection: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if exc_type is None: + await self.close() + else: + await self.close(CloseCode.INTERNAL_ERROR) + + async def __aiter__(self) -> AsyncIterator[Data]: + """ + Iterate on incoming messages. + + The iterator calls :meth:`recv` and yields messages asynchronously in an + infinite loop. + + It exits when the connection is closed normally. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` exception after a + protocol error or a network failure. + + """ + try: + while True: + yield await self.recv() + except ConnectionClosedOK: + return + + @overload + async def recv(self, decode: Literal[True]) -> str: ... + + @overload + async def recv(self, decode: Literal[False]) -> bytes: ... + + @overload + async def recv(self, decode: bool | None = None) -> Data: ... + + async def recv(self, decode: bool | None = None) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure + and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + Canceling :meth:`recv` is safe. There's no risk of losing data. The next + invocation of :meth:`recv` will return the next message. + + This makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~trio.move_on_after` or :func:`~trio.fail_after`. + + When the message is fragmented, :meth:`recv` waits until all fragments + are received, reassembles them, and returns the whole message. + + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + + Returns: + A string (:class:`str`) for a Text_ frame or a bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + return await self.recv_messages.get(decode) + except EOFError: + pass + # fallthrough + except ConcurrencyError: + raise ConcurrencyError( + "cannot call recv while another coroutine " + "is already running recv or recv_streaming" + ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await self.stream_closed.wait() + raise self.protocol.close_exc from self.recv_exc + + @overload + def recv_streaming(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def recv_streaming(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + + async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + + This method is designed for receiving fragmented messages. It returns an + asynchronous iterator that yields each fragment as it is received. This + iterator must be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise + :exc:`~websockets.exceptions.ConcurrencyError`, making the connection + unusable. + + :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + + Canceling :meth:`recv_streaming` before receiving the first frame is + safe. Canceling it after receiving one or more frames leaves the + iterator in a partially consumed state, making the connection unusable. + Instead, you should close the connection with :meth:`close`. + + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + + Returns: + An iterator of strings (:class:`str`) for a Text_ frame or + bytestrings (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return bytestrings (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return strings (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + async for frame in self.recv_messages.get_iter(decode): + yield frame + return + except EOFError: + pass + # fallthrough + except ConcurrencyError: + raise ConcurrencyError( + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming" + ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await self.stream_closed.wait() + raise self.protocol.close_exc from self.recv_exc + + async def send( + self, + message: Data | Iterable[Data] | AsyncIterable[Data], + text: bool | None = None, + ) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send a bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + + :meth:`send` also accepts an iterable or an asynchronous iterable of + strings, bytestrings, or bytes-like objects to enable fragmentation_. + Each item is treated as a message fragment and sent in its own frame. + All items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you really want to send the keys of a dict-like object as fragments, + call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + Canceling :meth:`send` is discouraged. Instead, you should close the + connection with :meth:`close`. Indeed, there are only two situations + where :meth:`send` may yield control to the event loop and then get + canceled; in both cases, :meth:`close` has the same effect and is + more clear: + + 1. The write buffer is full. If you don't want to wait until enough + data is sent, your only alternative is to close the connection. + :meth:`close` will likely time out then abort the TCP connection. + 2. ``message`` is an asynchronous iterator that yields control. + Stopping in the middle of a fragmented message will cause a + protocol error and the connection will be closed. + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message: Message to send. + + Raises: + ConnectionClosed: When the connection is closed. + TypeError: If ``message`` doesn't have a supported type. + + """ + # While sending a fragmented message, prevent sending other messages + # until all fragments are sent. + while self.send_in_progress is not None: + await self.send_in_progress.wait() + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + async with self.send_context(): + if text is False: + self.protocol.send_binary(message.encode()) + else: + self.protocol.send_text(message.encode()) + + elif isinstance(message, BytesLike): + async with self.send_context(): + if text is True: + self.protocol.send_text(message) + else: + self.protocol.send_binary(message) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + chunks = iter(message) + try: + chunk = next(chunks) + except StopIteration: + return + + assert self.send_in_progress is None + self.send_in_progress = trio.Event() + try: + # First fragment. + if isinstance(chunk, str): + async with self.send_context(): + if text is False: + self.protocol.send_binary(chunk.encode(), fin=False) + else: + self.protocol.send_text(chunk.encode(), fin=False) + encode = True + elif isinstance(chunk, BytesLike): + async with self.send_context(): + if text is True: + self.protocol.send_text(chunk, fin=False) + else: + self.protocol.send_binary(chunk, fin=False) + encode = False + else: + raise TypeError("iterable must contain bytes or str") + + # Other fragments + for chunk in chunks: + if isinstance(chunk, str) and encode: + async with self.send_context(): + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: + async with self.send_context(): + self.protocol.send_continuation(chunk, fin=False) + else: + raise TypeError("iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) + raise + + finally: + self.send_in_progress.set() + self.send_in_progress = None + + # Fragmented message -- async iterator. + + elif isinstance(message, AsyncIterable): + achunks = aiter(message) + try: + chunk = await anext(achunks) + except StopAsyncIteration: + return + + assert self.send_in_progress is None + self.send_in_progress = trio.Event() + try: + # First fragment. + if isinstance(chunk, str): + if text is False: + async with self.send_context(): + self.protocol.send_binary(chunk.encode(), fin=False) + else: + async with self.send_context(): + self.protocol.send_text(chunk.encode(), fin=False) + encode = True + elif isinstance(chunk, BytesLike): + if text is True: + async with self.send_context(): + self.protocol.send_text(chunk, fin=False) + else: + async with self.send_context(): + self.protocol.send_binary(chunk, fin=False) + encode = False + else: + raise TypeError("async iterable must contain bytes or str") + + # Other fragments + async for chunk in achunks: + if isinstance(chunk, str) and encode: + async with self.send_context(): + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: + async with self.send_context(): + self.protocol.send_continuation(chunk, fin=False) + else: + raise TypeError("async iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) + raise + + finally: + self.send_in_progress.set() + self.send_in_progress = None + + else: + raise TypeError("data must be str, bytes, iterable, or async iterable") + + async def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + # The context manager takes care of waiting for the TCP connection + # to terminate after calling a method that sends a close frame. + async with self.send_context(): + if self.send_in_progress is not None: + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "close during fragmented message", + ) + else: + self.protocol.send_close(code, reason) + except ConnectionClosed: + # Ignore ConnectionClosed exceptions raised from send_context(). + # They mean that the connection is closed, which was the goal. + pass + + async def wait_closed(self) -> None: + """ + Wait until the connection is closed. + + :meth:`wait_closed` waits for the closing handshake to complete and for + the TCP connection to terminate. + + """ + await self.stream_closed.wait() + + async def ping( + self, data: Data | None = None, ack_on_close: bool = False + ) -> trio.Event: + """ + Send a Ping_. + + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point + + Args: + data: Payload of the ping. A :class:`str` will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. + ack_on_close: when this option is :obj:`True`, the event will also + be set when the connection is closed. While this avoids getting + stuck waiting for a pong that will never arrive, it requires + checking that the state of the connection is still ``OPEN`` to + confirm that a pong was received, rather than the connection + being closed. + + Returns: + An event that will be set when the corresponding pong is received. + You can ignore it if you don't intend to wait. + + :: + + pong_received = await ws.ping() + # only if you want to wait for the corresponding pong + await pong_received.wait() + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + elif data is not None: + raise TypeError("data must be str or bytes-like") + + async with self.send_context(): + # Protect against duplicates if a payload is explicitly set. + if data in self.pending_pings: + raise ConcurrencyError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pending_pings: + data = struct.pack("!I", random.getrandbits(32)) + + pong_received = trio.Event() + self.pending_pings[data] = ( + pong_received, + trio.current_time(), + ack_on_close, + ) + self.protocol.send_ping(data) + return pong_received + + async def pong(self, data: Data = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Args: + data: Payload of the pong. A :class:`str` will be encoded to UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + else: + raise TypeError("data must be str or bytes-like") + + async with self.send_context(): + self.protocol.send_pong(data) + + # Private methods + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + This method is overridden in subclasses to handle the handshake. + + """ + assert isinstance(event, Frame) + if event.opcode in DATA_OPCODES: + self.recv_messages.put(event) + + if event.opcode is Opcode.PONG: + self.acknowledge_pings(bytes(event.data)) + + def acknowledge_pings(self, data: bytes) -> None: + """ + Acknowledge pings when receiving a pong. + + """ + # Ignore unsolicited pong. + if data not in self.pending_pings: + return + + pong_timestamp = trio.current_time() + + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, ( + pong_received, + ping_timestamp, + _ack_on_close, + ) in self.pending_pings.items(): + ping_ids.append(ping_id) + pong_received.set() + if ping_id == data: + self.latency = pong_timestamp - ping_timestamp + break + else: + raise AssertionError("solicited pong not found in pings") + + # Remove acknowledged pings from self.pending_pings. + for ping_id in ping_ids: + del self.pending_pings[ping_id] + + def acknowledge_pending_pings(self) -> None: + """ + Acknowledge pending pings when the connection is closed. + + """ + assert self.protocol.state is CLOSED + + for pong_received, _ping_timestamp, ack_on_close in self.pending_pings.values(): + if ack_on_close: + pong_received.set() + + self.pending_pings.clear() + + async def keepalive(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + """ + assert self.ping_interval is not None + latency = 0.0 + try: + while True: + # If self.ping_timeout > latency > self.ping_interval, + # pings will be sent immediately after receiving pongs. + # The period will be longer than self.ping_interval. + with trio.move_on_after(self.ping_interval - latency): + await self.stream_closed.wait() + break + + try: + pong_received = await self.ping(ack_on_close=True) + except ConnectionClosed: + break + if self.debug: + self.logger.debug("% sent keepalive ping") + + if self.ping_timeout is not None: + with trio.move_on_after(self.ping_timeout) as cancel_scope: + await pong_received.wait() + self.logger.debug("% received keepalive pong") + if cancel_scope.cancelled_caught: + if self.debug: + self.logger.debug("- timed out waiting for keepalive pong") + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) + break + except Exception: + self.logger.error("keepalive ping failed", exc_info=True) + + def start_keepalive(self) -> None: + """ + Run :meth:`keepalive` in a task, unless keepalive is disabled. + + """ + if self.ping_interval is not None: + self.nursery.start_soon(self.keepalive) + + async def recv_events(self) -> None: + """ + Read incoming data from the stream and process events. + + Run this method in a task as long as the connection is alive. + + ``recv_events()`` exits immediately when ``self.stream`` is closed. + + """ + try: + while True: + try: + data = await self.stream.receive_some() + except Exception as exc: + if self.debug: + self.logger.debug( + "! error while receiving data", + exc_info=True, + ) + # When the closing handshake is initiated by our side, + # recv() may block until send_context() closes the stream. + # In that case, send_context() already set recv_exc. + # Calling set_recv_exc() avoids overwriting it. + self.set_recv_exc(exc) + break + + if data == b"": + break + + # Feed incoming data to the protocol. + self.protocol.receive_data(data) + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # Write outgoing data to the socket. + try: + await self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug( + "! error while sending data", + exc_info=True, + ) + # Similarly to the above, avoid overriding an exception + # set by send_context(), in case of a race condition + # i.e. send_context() closes the transport after recv() + # returns above but before send_data() calls send(). + self.set_recv_exc(exc) + break + + if self.protocol.close_expected(): + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + if self.close_timeout is not None: + if self.close_deadline is None: + self.close_deadline = ( + trio.current_time() + self.close_timeout + ) + + # If self.send_data raised an exception, then events are lost. + # Given that automatic responses write small amounts of data, + # this should be uncommon, so we don't handle the edge case. + + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) + + # Breaking out of the while True: ... loop means that we believe + # that the socket doesn't work anymore. + # Feed the end of the data stream to the protocol. + self.protocol.receive_eof() + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # There is no error handling because send_data() can only write + # the end of the data stream here and it handles errors itself. + await self.send_data() + + # This code path is triggered when receiving an HTTP response + # without a Content-Length header. This is the only case where + # reading until EOF generates an event; all other events have + # a known length. Ignore for coverage measurement because tests + # are in test_client.py rather than test_connection.py. + for event in events: # pragma: no cover + # This isn't expected to raise an exception. + self.process_event(event) + + except Exception as exc: + # This branch should never run. It's a safety net in case of bugs. + self.logger.error("unexpected internal error", exc_info=True) + self.set_recv_exc(exc) + finally: + # This isn't expected to raise an exception. + await self.close_stream() + + @contextlib.asynccontextmanager + async def send_context( + self, + *, + expected_state: State = OPEN, # CONNECTING during the opening handshake + ) -> AsyncIterator[None]: + """ + Create a context for writing to the connection from user code. + + On entry, :meth:`send_context` checks that the connection is open; on + exit, it writes outgoing data to the socket:: + + async with self.send_context(): + self.protocol.send_text(message.encode()) + + When the connection isn't open on entry, when the connection is expected + to close on exit, or when an unexpected error happens, terminating the + connection, :meth:`send_context` waits until the connection is closed + then raises :exc:`~websockets.exceptions.ConnectionClosed`. + + """ + # Should we wait until the connection is closed? + wait_for_close = False + # Should we close the transport and raise ConnectionClosed? + raise_close_exc = False + # What exception should we chain ConnectionClosed to? + original_exc: BaseException | None = None + + if self.protocol.state is expected_state: + # Let the caller interact with the protocol. + try: + yield + except (ProtocolError, ConcurrencyError): + # The protocol state wasn't changed. Exit immediately. + raise + except Exception as exc: + self.logger.error("unexpected internal error", exc_info=True) + # This branch should never run. It's a safety net in case of + # bugs. Since we don't know what happened, we will close the + # connection and raise the exception to the caller. + wait_for_close = False + raise_close_exc = True + original_exc = exc + else: + # Check if the connection is expected to close soon. + if self.protocol.close_expected(): + wait_for_close = True + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + # Since we tested earlier that protocol.state was OPEN + # (or CONNECTING), self.close_deadline is still None. + if self.close_timeout is not None: + assert self.close_deadline is None + self.close_deadline = trio.current_time() + self.close_timeout + # Write outgoing data to the socket and enforce flow control. + try: + await self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug("! error while sending data", exc_info=True) + # While the only expected exception here is OSError, + # other exceptions would be treated identically. + wait_for_close = False + raise_close_exc = True + original_exc = exc + + else: # self.protocol.state is not expected_state + # Minor layering violation: we assume that the connection + # will be closing soon if it isn't in the expected state. + wait_for_close = True + # Calculate close_deadline if it wasn't set yet. + if self.close_timeout is not None: + if self.close_deadline is None: + self.close_deadline = trio.current_time() + self.close_timeout + raise_close_exc = True + + # If the connection is expected to close soon and the close timeout + # elapses, close the socket to terminate the connection. + if wait_for_close: + if self.close_deadline is not None: + with trio.move_on_at(self.close_deadline) as cancel_scope: + await self.stream_closed.wait() + if cancel_scope.cancelled_caught: + # There's no risk to overwrite another error because + # original_exc is never set when wait_for_close is True. + assert original_exc is None + original_exc = TimeoutError("timed out while closing connection") + # Set recv_exc before closing the transport in order to get + # proper exception reporting. + raise_close_exc = True + self.set_recv_exc(original_exc) + else: + await self.stream_closed.wait() + + # If an error occurred, close the transport to terminate the connection and + # raise an exception. + if raise_close_exc: + await self.close_stream() + raise self.protocol.close_exc from original_exc + + async def send_data(self) -> None: + """ + Send outgoing data. + + """ + for data in self.protocol.data_to_send(): + if data: + await self.stream.send_all(data) + else: + # Half-close the TCP connection when possible i.e. no TLS. + if isinstance(self.stream, trio.abc.HalfCloseableStream): + if self.debug: + self.logger.debug("x half-closing TCP connection") + try: + await self.stream.send_eof() + except Exception: # pragma: no cover + pass + # Else, close the TCP connection. + else: # pragma: no cover + if self.debug: + self.logger.debug("x closing TCP connection") + await self.stream.aclose() + + def set_recv_exc(self, exc: BaseException | None) -> None: + """ + Set recv_exc, if not set yet. + + """ + if self.recv_exc is None: + self.recv_exc = exc + + async def close_stream(self) -> None: + """ + Shutdown and close stream. Close message assembler. + + Calling close_stream() guarantees that recv_events() terminates. Indeed, + recv_events() may block only on stream.recv() or on recv_messages.put(). + + """ + # Close the stream. + await self.stream.aclose() + + # Calling protocol.receive_eof() is safe because it's idempotent. + # This guarantees that the protocol state becomes CLOSED. + self.protocol.receive_eof() + assert self.protocol.state is CLOSED + + # Abort recv() with a ConnectionClosed exception. + self.recv_messages.close() + + # Acknowledge pings sent with the ack_on_close option. + self.acknowledge_pending_pings() + + # Unblock coroutines waiting on self.stream_closed. + self.stream_closed.set() diff --git a/src/websockets/trio/messages.py b/src/websockets/trio/messages.py new file mode 100644 index 000000000..65f9759ef --- /dev/null +++ b/src/websockets/trio/messages.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +import codecs +import math +from collections.abc import AsyncIterator +from typing import Any, Callable, Literal, TypeVar, overload + +import trio + +from ..exceptions import ConcurrencyError +from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame +from ..typing import Data + + +__all__ = ["Assembler"] + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + +T = TypeVar("T") + + +class Assembler: + """ + Assemble messages from frames. + + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + + """ + + def __init__( + self, + high: int | None = None, + low: int | None = None, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: + # Queue of incoming frames. + self.send_frames: trio.MemorySendChannel[Frame] + self.recv_frames: trio.MemoryReceiveChannel[Frame] + self.send_frames, self.recv_frames = trio.open_memory_channel(math.inf) + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + if high is not None and low is None: + low = high // 4 + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") + self.high, self.low = high, low + self.pause = pause + self.resume = resume + self.paused = False + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # This flag marks the end of the connection. + self.closed = False + + @overload + async def get(self, decode: Literal[True]) -> str: ... + + @overload + async def get(self, decode: Literal[False]) -> bytes: ... + + @overload + async def get(self, decode: bool | None = None) -> Data: ... + + async def get(self, decode: bool | None = None) -> Data: + """ + Read the next message. + + :meth:`get` returns a single :class:`str` or :class:`bytes`. + + If the message is fragmented, :meth:`get` waits until the last frame is + received, then it reassembles the message and returns it. To receive + messages frame by frame, use :meth:`get_iter` instead. + + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + + Raises: + EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. + + """ + if self.get_in_progress: + raise ConcurrencyError("get() or get_iter() is already running") + self.get_in_progress = True + + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or is canceled. + + try: + # First frame + try: + frame = await self.recv_frames.receive() + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = await self.recv_frames.receive() + except trio.Cancelled: + # Put frames already received back into the queue + # so that future calls to get() can return them. + assert not self.send_frames._state.receive_tasks, ( + "no task should be waiting on receive()" + ) + assert not self.send_frames._state.data, "queue should be empty" + for frame in frames: + self.send_frames.send_nowait(frame) + raise + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + finally: + self.get_in_progress = False + + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data + + @overload + def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + + async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Stream the next message. + + Iterating the return value of :meth:`get_iter` asynchronously yields a + :class:`str` or :class:`bytes` for each frame in the message. + + The iterator must be fully consumed before calling :meth:`get_iter` or + :meth:`get` again. Else, :exc:`ConcurrencyError` is raised. + + This method only makes sense for fragmented messages. If messages aren't + fragmented, use :meth:`get` instead. + + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + + Raises: + EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. + + """ + if self.get_in_progress: + raise ConcurrencyError("get() or get_iter() is already running") + self.get_in_progress = True + + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or is canceled. + + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. + + # First frame + try: + frame = await self.recv_frames.receive() + except trio.Cancelled: + self.get_in_progress = False + raise + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + # Following frames, for fragmented messages + while not frame.fin: + # We cannot handle trio.Cancelled because we don't buffer + # previous fragments — we're streaming them. Canceling get_iter() + # here will leave the assembler in a stuck state. Future calls to + # get() or get_iter() will raise ConcurrencyError. + try: + frame = await self.recv_frames.receive() + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + self.get_in_progress = False + + def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + + Raises: + EOFError: If the stream of frames has ended. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + self.send_frames.send_nowait(frame) + self.maybe_pause() + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled + if self.high is None: + return + + # Bypass the statistics() method for performance reasons. + # Check for "> high" to support high = 0 + if len(self.send_frames._state.data) > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled + if self.low is None: + return + + # Bypass the statistics() method for performance reasons. + # Check for "<= low" to support low = 0 + if len(self.send_frames._state.data) <= self.low and self.paused: + self.paused = False + self.resume() + + def close(self) -> None: + """ + End the stream of frames. + + Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + or :meth:`put` is safe. They will raise :exc:`trio.EndOfChannel`. + + """ + if self.closed: + return + + self.closed = True + + # Unblock get() or get_iter(). + self.send_frames.close() diff --git a/tests/asyncio/connection.py b/tests/asyncio/connection.py index ad1c121bf..854b9bb99 100644 --- a/tests/asyncio/connection.py +++ b/tests/asyncio/connection.py @@ -21,7 +21,7 @@ def delay_frames_sent(self, delay): """ Add a delay before sending frames. - This can result in out-of-order writes, which is unrealistic. + Misuse can result in out-of-order writes, which is unrealistic. """ assert self.transport.delay_write is None @@ -36,7 +36,7 @@ def delay_eof_sent(self, delay): """ Add a delay before sending EOF. - This can result in out-of-order writes, which is unrealistic. + Misuse can result in out-of-order writes, which is unrealistic. """ assert self.transport.delay_write_eof is None @@ -83,9 +83,9 @@ class InterceptingTransport: This is coupled to the implementation, which relies on these two methods. - Since ``write()`` and ``write_eof()`` are not coroutines, this effect is - achieved by scheduling writes at a later time, after the methods return. - This can easily result in out-of-order writes, which is unrealistic. + Since ``write()`` and ``write_eof()`` are synchronous, we can only schedule + writes at a later time, after they return. This is unrealistic and can lead + to out-of-order writes if tests aren't written carefully. """ @@ -101,15 +101,13 @@ def __getattr__(self, name): return getattr(self.transport, name) def write(self, data): - if not self.drop_write: - if self.delay_write is not None: - self.loop.call_later(self.delay_write, self.transport.write, data) - else: - self.transport.write(data) + if self.delay_write is not None: + self.loop.call_later(self.delay_write, self.transport.write, data) + elif not self.drop_write: + self.transport.write(data) def write_eof(self): - if not self.drop_write_eof: - if self.delay_write_eof is not None: - self.loop.call_later(self.delay_write_eof, self.transport.write_eof) - else: - self.transport.write_eof() + if self.delay_write_eof is not None: + self.loop.call_later(self.delay_write_eof, self.transport.write_eof) + elif not self.drop_write_eof: + self.transport.write_eof() diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 668f55cbd..29450a043 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -33,13 +33,13 @@ class ClientConnectionTests(AssertNoLogsMixin, unittest.IsolatedAsyncioTestCase) REMOTE = SERVER async def asyncSetUp(self): - loop = asyncio.get_running_loop() + self.loop = asyncio.get_running_loop() socket_, remote_socket = socket.socketpair() - self.transport, self.connection = await loop.create_connection( + self.transport, self.connection = await self.loop.create_connection( lambda: Connection(Protocol(self.LOCAL), close_timeout=2 * MS), sock=socket_, ) - self.remote_transport, self.remote_connection = await loop.create_connection( + _remote_transport, self.remote_connection = await self.loop.create_connection( lambda: InterceptingConnection(RecordingProtocol(self.REMOTE)), sock=remote_socket, ) @@ -125,41 +125,41 @@ async def test_exit_with_exception(self): async def test_aiter_text(self): """__aiter__ yields text messages.""" - aiterator = aiter(self.connection) + iterator = aiter(self.connection) await self.remote_connection.send("😀") - self.assertEqual(await anext(aiterator), "😀") + self.assertEqual(await anext(iterator), "😀") await self.remote_connection.send("😀") - self.assertEqual(await anext(aiterator), "😀") + self.assertEqual(await anext(iterator), "😀") async def test_aiter_binary(self): """__aiter__ yields binary messages.""" - aiterator = aiter(self.connection) + iterator = aiter(self.connection) await self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") await self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") async def test_aiter_mixed(self): """__aiter__ yields a mix of text and binary messages.""" - aiterator = aiter(self.connection) + iterator = aiter(self.connection) await self.remote_connection.send("😀") - self.assertEqual(await anext(aiterator), "😀") + self.assertEqual(await anext(iterator), "😀") await self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") async def test_aiter_connection_closed_ok(self): """__aiter__ terminates after a normal closure.""" - aiterator = aiter(self.connection) + iterator = aiter(self.connection) await self.remote_connection.close() with self.assertRaises(StopAsyncIteration): - await anext(aiterator) + await anext(iterator) async def test_aiter_connection_closed_error(self): """__aiter__ raises ConnectionClosedError after an error.""" - aiterator = aiter(self.connection) + iterator = aiter(self.connection) await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): - await anext(aiterator) + await anext(iterator) # Test recv. @@ -245,7 +245,7 @@ async def test_recv_during_recv_streaming(self): ) async def test_recv_cancellation_before_receiving(self): - """recv can be canceled before receiving a frame.""" + """recv can be canceled before receiving a message.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(0) # let the event loop start recv_task @@ -257,11 +257,8 @@ async def test_recv_cancellation_before_receiving(self): self.assertEqual(await self.connection.recv(), "😀") async def test_recv_cancellation_while_receiving(self): - """recv cannot be canceled after receiving a frame.""" - recv_task = asyncio.create_task(self.connection.recv()) - await asyncio.sleep(0) # let the event loop start recv_task - - gate = asyncio.get_running_loop().create_future() + """recv can be canceled while receiving a fragmented message.""" + gate = self.loop.create_future() async def fragments(): yield "⏳" @@ -269,13 +266,16 @@ async def fragments(): yield "⌛️" asyncio.create_task(self.remote_connection.send(fragments())) - await asyncio.sleep(MS) + + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task recv_task.cancel() await asyncio.sleep(0) # let the event loop cancel recv_task - # Running recv again receives the complete message. gate.set_result(None) + + # Running recv again receives the complete message. self.assertEqual(await self.connection.recv(), "⏳⌛️") # Test recv_streaming. @@ -360,8 +360,7 @@ async def test_recv_streaming_during_recv(self): self.addCleanup(recv_task.cancel) with self.assertRaises(ConcurrencyError) as raised: - async for _ in self.connection.recv_streaming(): - self.fail("did not raise") + await alist(self.connection.recv_streaming()) self.assertEqual( str(raised.exception), "cannot call recv_streaming while another coroutine " @@ -377,8 +376,7 @@ async def test_recv_streaming_during_recv_streaming(self): self.addCleanup(recv_streaming_task.cancel) with self.assertRaises(ConcurrencyError) as raised: - async for _ in self.connection.recv_streaming(): - self.fail("did not raise") + await alist(self.connection.recv_streaming()) self.assertEqual( str(raised.exception), r"cannot call recv_streaming while another coroutine " @@ -409,7 +407,7 @@ async def test_recv_streaming_cancellation_while_receiving(self): ) await asyncio.sleep(0) # let the event loop start recv_streaming_task - gate = asyncio.get_running_loop().create_future() + gate = self.loop.create_future() async def fragments(): yield "⏳" @@ -423,6 +421,7 @@ async def fragments(): await asyncio.sleep(0) # let the event loop cancel recv_streaming_task gate.set_result(None) + # Running recv_streaming again fails. with self.assertRaises(ConcurrencyError): await alist(self.connection.recv_streaming()) @@ -555,7 +554,7 @@ async def test_send_connection_closed_error(self): async def test_send_while_send_blocked(self): """send waits for a previous call to send to complete.""" - # This test fails if the guard with fragmented_send_waiter is removed + # This test fails if the guard with send_in_progress is removed # from send() in the case when message is an Iterable. self.connection.pause_writing() asyncio.create_task(self.connection.send(["⏳", "⌛️"])) @@ -580,7 +579,7 @@ async def test_send_while_send_blocked(self): async def test_send_while_send_async_blocked(self): """send waits for a previous call to send to complete.""" - # This test fails if the guard with fragmented_send_waiter is removed + # This test fails if the guard with send_in_progress is removed # from send() in the case when message is an AsyncIterable. self.connection.pause_writing() @@ -610,9 +609,9 @@ async def fragments(): async def test_send_during_send_async(self): """send waits for a previous call to send to complete.""" - # This test fails if the guard with fragmented_send_waiter is removed + # This test fails if the guard with send_in_progress is removed # from send() in the case when message is an AsyncIterable. - gate = asyncio.get_running_loop().create_future() + gate = self.loop.create_future() async def fragments(): yield "⏳" @@ -709,8 +708,14 @@ async def test_close_explicit_code_reason(self): async def test_close_waits_for_close_frame(self): """close waits for a close frame (then EOF) before returning.""" + t0 = self.loop.time() async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -724,8 +729,14 @@ async def test_close_waits_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = self.loop.time() async with self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -738,8 +749,14 @@ async def test_close_no_timeout_waits_for_close_frame(self): """close without timeout waits for a close frame (then EOF) before returning.""" self.connection.close_timeout = None + t0 = self.loop.time() async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -755,8 +772,14 @@ async def test_close_no_timeout_waits_for_connection_closed(self): self.connection.close_timeout = None + t0 = self.loop.time() async with self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -767,8 +790,14 @@ async def test_close_no_timeout_waits_for_connection_closed(self): async def test_close_timeout_waiting_for_close_frame(self): """close times out if no close frame is received.""" + t0 = self.loop.time() async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.ABNORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() @@ -782,8 +811,14 @@ async def test_close_timeout_waiting_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = self.loop.time() async with self.drop_eof_rcvd(): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -799,13 +834,9 @@ async def test_close_preserves_queued_messages(self): await self.connection.close() self.assertEqual(await self.connection.recv(), "😀") - with self.assertRaises(ConnectionClosedOK) as raised: + with self.assertRaises(ConnectionClosedOK): await self.connection.recv() - exc = raised.exception - self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") - self.assertIsNone(exc.__cause__) - async def test_close_idempotency(self): """close does nothing if the connection is already closed.""" await self.connection.close() @@ -816,11 +847,15 @@ async def test_close_idempotency(self): async def test_close_during_recv(self): """close aborts recv when called concurrently with recv.""" - recv_task = asyncio.create_task(self.connection.recv()) - await asyncio.sleep(MS) - await self.connection.close() + + async def closer(): + await asyncio.sleep(MS) + await self.connection.close() + + asyncio.create_task(closer()) + with self.assertRaises(ConnectionClosedOK) as raised: - await recv_task + await self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") @@ -828,23 +863,24 @@ async def test_close_during_recv(self): async def test_close_during_send(self): """close fails the connection when called concurrently with send.""" - gate = asyncio.get_running_loop().create_future() + close_gate = self.loop.create_future() + exit_gate = self.loop.create_future() + + async def closer(): + await close_gate + await self.connection.close() + exit_gate.set_result(None) async def fragments(): yield "⏳" - await gate + close_gate.set_result(None) + await exit_gate yield "⌛️" - send_task = asyncio.create_task(self.connection.send(fragments())) - await asyncio.sleep(MS) - - asyncio.create_task(self.connection.close()) - await asyncio.sleep(MS) - - gate.set_result(None) + asyncio.create_task(closer()) with self.assertRaises(ConnectionClosedError) as raised: - await send_task + await self.connection.send(fragments()) exc = raised.exception self.assertEqual( @@ -886,54 +922,54 @@ async def test_ping_explicit_binary(self): async def test_acknowledge_ping(self): """ping is acknowledged by a pong with the same payload.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") + pong_received = await self.connection.ping("this") await self.remote_connection.pong("this") async with asyncio_timeout(MS): - await pong_waiter + await pong_received async def test_acknowledge_canceled_ping(self): """ping is acknowledged by a pong with the same payload after being canceled.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") - pong_waiter.cancel() + pong_received = await self.connection.ping("this") + pong_received.cancel() await self.remote_connection.pong("this") with self.assertRaises(asyncio.CancelledError): - await pong_waiter + await pong_received async def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") + pong_received = await self.connection.ping("this") await self.remote_connection.pong("that") with self.assertRaises(TimeoutError): async with asyncio_timeout(MS): - await pong_waiter + await pong_received async def test_acknowledge_previous_ping(self): """ping is acknowledged by a pong for a later ping.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") + pong_received = await self.connection.ping("this") await self.connection.ping("that") await self.remote_connection.pong("that") async with asyncio_timeout(MS): - await pong_waiter + await pong_received async def test_acknowledge_previous_canceled_ping(self): """ping is acknowledged by a pong for a later ping after being canceled.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") - pong_waiter_2 = await self.connection.ping("that") - pong_waiter.cancel() + pong_received = await self.connection.ping("this") + pong_received_2 = await self.connection.ping("that") + pong_received.cancel() await self.remote_connection.pong("that") async with asyncio_timeout(MS): - await pong_waiter_2 + await pong_received_2 with self.assertRaises(asyncio.CancelledError): - await pong_waiter + await pong_received async def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("idem") + pong_received = await self.connection.ping("idem") with self.assertRaises(ConcurrencyError) as raised: await self.connection.ping("idem") @@ -944,7 +980,7 @@ async def test_ping_duplicate_payload(self): await self.remote_connection.pong("idem") async with asyncio_timeout(MS): - await pong_waiter + await pong_received await self.connection.ping("idem") # doesn't raise an exception @@ -1034,6 +1070,7 @@ async def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" self.connection.ping_interval = 3 * MS self.connection.start_keepalive() + self.assertFalse(self.connection.keepalive_task.done()) await asyncio.sleep(MS) await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) @@ -1062,9 +1099,9 @@ async def test_keepalive_reports_errors(self): await asyncio.sleep(2 * MS) # Exiting the context manager sleeps for 1 ms. # 3 ms: inject a fault: raise an exception in the pending pong waiter. - pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] + pong_received = next(iter(self.connection.pending_pings.values()))[0] with self.assertLogs("websockets", logging.ERROR) as logs: - pong_waiter.set_exception(Exception("BOOM")) + pong_received.set_exception(Exception("BOOM")) await asyncio.sleep(0) self.assertEqual( [record.getMessage() for record in logs.records], @@ -1079,20 +1116,28 @@ async def test_keepalive_reports_errors(self): async def test_close_timeout(self): """close_timeout parameter configures close timeout.""" - connection = Connection(Protocol(self.LOCAL), close_timeout=42 * MS) + connection = Connection( + Protocol(self.LOCAL), + close_timeout=42 * MS, + ) self.assertEqual(connection.close_timeout, 42 * MS) async def test_max_queue(self): """max_queue configures high-water mark of frames buffer.""" - connection = Connection(Protocol(self.LOCAL), max_queue=4) - transport = Mock() - connection.connection_made(transport) + connection = Connection( + Protocol(self.LOCAL), + max_queue=4, + ) + connection.connection_made(Mock(spec=asyncio.Transport)) self.assertEqual(connection.recv_messages.high, 4) async def test_max_queue_none(self): """max_queue disables high-water mark of frames buffer.""" - connection = Connection(Protocol(self.LOCAL), max_queue=None) - transport = Mock() + connection = Connection( + Protocol(self.LOCAL), + max_queue=None, + ) + transport = Mock(spec=asyncio.Transport) connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, None) self.assertEqual(connection.recv_messages.low, None) @@ -1103,7 +1148,7 @@ async def test_max_queue_tuple(self): Protocol(self.LOCAL), max_queue=(4, 2), ) - transport = Mock() + transport = Mock(spec=asyncio.Transport) connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) self.assertEqual(connection.recv_messages.low, 2) @@ -1114,7 +1159,7 @@ async def test_write_limit(self): Protocol(self.LOCAL), write_limit=4096, ) - transport = Mock() + transport = Mock(spec=asyncio.Transport) connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, None) @@ -1124,7 +1169,7 @@ async def test_write_limits(self): Protocol(self.LOCAL), write_limit=(4096, 2048), ) - transport = Mock() + transport = Mock(spec=asyncio.Transport) connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, 2048) @@ -1138,13 +1183,13 @@ async def test_logger(self): """Connection has a logger attribute.""" self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) - @patch("asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234)) + @patch("asyncio.Transport.get_extra_info", return_value=("sock", 1234)) async def test_local_address(self, get_extra_info): """Connection provides a local_address attribute.""" self.assertEqual(self.connection.local_address, ("sock", 1234)) get_extra_info.assert_called_with("sockname") - @patch("asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234)) + @patch("asyncio.Transport.get_extra_info", return_value=("peer", 1234)) async def test_remote_address(self, get_extra_info): """Connection provides a remote_address attribute.""" self.assertEqual(self.connection.remote_address, ("peer", 1234)) @@ -1181,27 +1226,27 @@ async def test_writing_in_data_received_fails(self): # Inject a fault by shutting down the transport for writing — but not by # closing it because that would terminate the connection. self.transport.write_eof() + # Receive a ping. Responding with a pong will fail. await self.remote_connection.ping() # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() - cause = raised.exception.__cause__ - self.assertEqual(str(cause), "Cannot call write() after write_eof()") - self.assertIsInstance(cause, RuntimeError) + + self.assertIsInstance(raised.exception.__cause__, RuntimeError) async def test_writing_in_send_context_fails(self): """Error when sending outgoing frame is correctly reported.""" # Inject a fault by shutting down the transport for writing — but not by # closing it because that would terminate the connection. self.transport.write_eof() + # Sending a pong will fail. # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: await self.connection.pong() - cause = raised.exception.__cause__ - self.assertEqual(str(cause), "Cannot call write() after write_eof()") - self.assertIsInstance(cause, RuntimeError) + + self.assertIsInstance(raised.exception.__cause__, RuntimeError) # Test safety nets — catching all exceptions in case of bugs. @@ -1216,9 +1261,7 @@ async def test_unexpected_failure_in_data_received(self, events_received): with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) + self.assertIsInstance(raised.exception.__cause__, AssertionError) # Inject a fault in a random call in send_context(). # This test is tightly coupled to the implementation. @@ -1230,9 +1273,7 @@ async def test_unexpected_failure_in_send_context(self, send_text): with self.assertRaises(ConnectionClosedError) as raised: await self.connection.send("😀") - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) + self.assertIsInstance(raised.exception.__cause__, AssertionError) # Test broadcast. @@ -1303,7 +1344,7 @@ async def test_broadcast_skips_closing_connection(self): async def test_broadcast_skips_connection_with_send_blocked(self): """broadcast logs a warning when a connection is blocked in send.""" - gate = asyncio.get_running_loop().create_future() + gate = self.loop.create_future() async def fragments(): yield "⏳" @@ -1330,7 +1371,7 @@ async def fragments(): ) async def test_broadcast_reports_connection_with_send_blocked(self): """broadcast raises exceptions for connections blocked in send.""" - gate = asyncio.get_running_loop().create_future() + gate = self.loop.create_future() async def fragments(): yield "⏳" diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index a90788d02..340aa00a8 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -267,6 +267,7 @@ async def test_get_iter_fragmented_text_message_not_received_yet(self): self.assertEqual(await anext(iterator), "f") self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(await anext(iterator), "é") + await iterator.aclose() async def test_get_iter_fragmented_binary_message_not_received_yet(self): """get_iter yields a fragmented binary message when it is received.""" @@ -277,6 +278,7 @@ async def test_get_iter_fragmented_binary_message_not_received_yet(self): self.assertEqual(await anext(iterator), b"e") self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(await anext(iterator), b"a") + await iterator.aclose() async def test_get_iter_fragmented_text_message_being_received(self): """get_iter yields a fragmented text message that is partially received.""" @@ -287,6 +289,7 @@ async def test_get_iter_fragmented_text_message_being_received(self): self.assertEqual(await anext(iterator), "f") self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(await anext(iterator), "é") + await iterator.aclose() async def test_get_iter_fragmented_binary_message_being_received(self): """get_iter yields a fragmented binary message that is partially received.""" @@ -297,6 +300,7 @@ async def test_get_iter_fragmented_binary_message_being_received(self): self.assertEqual(await anext(iterator), b"e") self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(await anext(iterator), b"a") + await iterator.aclose() async def test_get_iter_encoded_text_message(self): """get_iter yields a text message without UTF-8 decoding.""" @@ -334,6 +338,8 @@ async def test_get_iter_resumes_reading(self): await anext(iterator) self.resume.assert_called_once_with() + await iterator.aclose() + async def test_get_iter_does_not_resume_reading(self): """get_iter does not resume reading when the low-water mark is unset.""" self.assembler.low = None @@ -345,6 +351,7 @@ async def test_get_iter_does_not_resume_reading(self): await anext(iterator) await anext(iterator) await anext(iterator) + await iterator.aclose() self.resume.assert_not_called() @@ -467,7 +474,7 @@ async def test_get_iter_queued_fragmented_message_after_close(self): self.assertEqual(fragments, [b"t", b"e", b"a"]) async def test_get_partially_queued_fragmented_message_after_close(self): - """get raises EOF on a partial fragmented message after close is called.""" + """get raises EOFError on a partial fragmented message after close is called.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.close() diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index a5aee35bb..157aa2056 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -6,7 +6,7 @@ import time import unittest import uuid -from unittest.mock import patch +from unittest.mock import Mock, patch from websockets.exceptions import ( ConcurrencyError, @@ -489,8 +489,14 @@ def test_close_explicit_code_reason(self): def test_close_waits_for_close_frame(self): """close waits for a close frame (then EOF) before returning.""" + t0 = time.time() with self.delay_frames_rcvd(MS): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -504,8 +510,14 @@ def test_close_waits_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = time.time() with self.delay_eof_rcvd(MS): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -516,8 +528,14 @@ def test_close_waits_for_connection_closed(self): def test_close_timeout_waiting_for_close_frame(self): """close times out if no close frame is received.""" + t0 = time.time() with self.drop_frames_rcvd(), self.drop_eof_rcvd(): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.ABNORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedError) as raised: self.connection.recv() @@ -531,8 +549,14 @@ def test_close_timeout_waiting_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = time.time() with self.drop_eof_rcvd(): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -548,13 +572,9 @@ def test_close_preserves_queued_messages(self): self.connection.close() self.assertEqual(self.connection.recv(), "😀") - with self.assertRaises(ConnectionClosedOK) as raised: + with self.assertRaises(ConnectionClosedOK): self.connection.recv() - exc = raised.exception - self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") - self.assertIsNone(exc.__cause__) - def test_close_idempotency(self): """close does nothing if the connection is already closed.""" self.connection.close() @@ -622,10 +642,10 @@ def closer(): exit_gate.set() def fragments(): - yield "😀" + yield "⏳" close_gate.set() exit_gate.wait() - yield "😀" + yield "⌛️" close_thread = threading.Thread(target=closer) close_thread.start() @@ -665,38 +685,38 @@ def test_ping_explicit_binary(self): def test_acknowledge_ping(self): """ping is acknowledged by a pong with the same payload.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("this") + pong_received = self.connection.ping("this") self.remote_connection.pong("this") - self.assertTrue(pong_waiter.wait(MS)) + self.assertTrue(pong_received.wait(MS)) def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("this") + pong_received = self.connection.ping("this") self.remote_connection.pong("that") - self.assertFalse(pong_waiter.wait(MS)) + self.assertFalse(pong_received.wait(MS)) def test_acknowledge_previous_ping(self): """ping is acknowledged by a pong for as a later ping.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("this") + pong_received = self.connection.ping("this") self.connection.ping("that") self.remote_connection.pong("that") - self.assertTrue(pong_waiter.wait(MS)) + self.assertTrue(pong_received.wait(MS)) def test_acknowledge_ping_on_close(self): """ping with ack_on_close is acknowledged when the connection is closed.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter_ack_on_close = self.connection.ping("this", ack_on_close=True) - pong_waiter = self.connection.ping("that") + pong_received_ack_on_close = self.connection.ping("this", ack_on_close=True) + pong_received = self.connection.ping("that") self.connection.close() - self.assertTrue(pong_waiter_ack_on_close.wait(MS)) - self.assertFalse(pong_waiter.wait(MS)) + self.assertTrue(pong_received_ack_on_close.wait(MS)) + self.assertFalse(pong_received.wait(MS)) def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("idem") + pong_received = self.connection.ping("idem") with self.assertRaises(ConcurrencyError) as raised: self.connection.ping("idem") @@ -706,7 +726,7 @@ def test_ping_duplicate_payload(self): ) self.remote_connection.pong("idem") - self.assertTrue(pong_waiter.wait(MS)) + self.assertTrue(pong_received.wait(MS)) self.connection.ping("idem") # doesn't raise an exception @@ -742,7 +762,7 @@ def test_pong_unsupported_type(self): @patch("random.getrandbits", return_value=1918987876) def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" - self.connection.ping_interval = 4 * MS + self.connection.ping_interval = 3 * MS self.connection.start_keepalive() self.assertIsNotNone(self.connection.keepalive_thread) self.assertEqual(self.connection.latency, 0) @@ -796,6 +816,7 @@ def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" self.connection.ping_interval = 3 * MS self.connection.start_keepalive() + self.assertTrue(self.connection.keepalive_thread.is_alive()) time.sleep(MS) self.connection.close() self.connection.keepalive_thread.join(MS) @@ -803,8 +824,9 @@ def test_keepalive_terminates_while_sleeping(self): def test_keepalive_terminates_when_sending_ping_fails(self): """keepalive task terminates when sending a ping fails.""" - self.connection.ping_interval = 1 * MS + self.connection.ping_interval = MS self.connection.start_keepalive() + self.assertTrue(self.connection.keepalive_thread.is_alive()) with self.drop_eof_rcvd(), self.drop_frames_rcvd(): self.connection.close() self.assertFalse(self.connection.keepalive_thread.is_alive()) @@ -827,14 +849,13 @@ def test_keepalive_terminates_while_waiting_for_pong(self): def test_keepalive_reports_errors(self): """keepalive reports unexpected errors in logs.""" self.connection.ping_interval = 2 * MS - with self.drop_frames_rcvd(): - self.connection.start_keepalive() - # 2 ms: keepalive() sends a ping frame. - # 2.x ms: a pong frame is dropped. - with self.assertLogs("websockets", logging.ERROR) as logs: - with patch("threading.Event.wait", side_effect=Exception("BOOM")): - time.sleep(3 * MS) - # Exiting the context manager sleeps for 1 ms. + self.connection.start_keepalive() + # Inject a fault when waiting to receive a pong. + with self.assertLogs("websockets", logging.ERROR) as logs: + with patch("threading.Event.wait", side_effect=Exception("BOOM")): + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is dropped. + time.sleep(3 * MS) self.assertEqual( [record.getMessage() for record in logs.records], ["keepalive ping failed"], @@ -848,11 +869,8 @@ def test_keepalive_reports_errors(self): def test_close_timeout(self): """close_timeout parameter configures close timeout.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), close_timeout=42 * MS, ) @@ -860,11 +878,8 @@ def test_close_timeout(self): def test_max_queue(self): """max_queue configures high-water mark of frames buffer.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), max_queue=4, ) @@ -872,11 +887,8 @@ def test_max_queue(self): def test_max_queue_none(self): """max_queue disables high-water mark of frames buffer.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), max_queue=None, ) @@ -885,11 +897,8 @@ def test_max_queue_none(self): def test_max_queue_tuple(self): """max_queue configures high-water and low-water marks of frames buffer.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), max_queue=(4, 2), ) @@ -960,11 +969,13 @@ def test_writing_in_recv_events_fails(self): # Inject a fault by shutting down the socket for writing — but not by # closing it because that would terminate the connection. self.connection.socket.shutdown(socket.SHUT_WR) + # Receive a ping. Responding with a pong will fail. self.remote_connection.ping() # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: self.connection.recv() + self.assertIsInstance(raised.exception.__cause__, BrokenPipeError) def test_writing_in_send_context_fails(self): @@ -972,10 +983,12 @@ def test_writing_in_send_context_fails(self): # Inject a fault by shutting down the socket for writing — but not by # closing it because that would terminate the connection. self.connection.socket.shutdown(socket.SHUT_WR) + # Sending a pong will fail. # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: self.connection.pong() + self.assertIsInstance(raised.exception.__cause__, BrokenPipeError) # Test safety nets — catching all exceptions in case of bugs. @@ -991,9 +1004,7 @@ def test_unexpected_failure_in_recv_events(self, events_received): with self.assertRaises(ConnectionClosedError) as raised: self.connection.recv() - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) + self.assertIsInstance(raised.exception.__cause__, AssertionError) # Inject a fault in a random call in send_context(). # This test is tightly coupled to the implementation. @@ -1005,9 +1016,7 @@ def test_unexpected_failure_in_send_context(self, send_text): with self.assertRaises(ConnectionClosedError) as raised: self.connection.send("😀") - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) + self.assertIsInstance(raised.exception.__cause__, AssertionError) class ServerConnectionTests(ClientConnectionTests): diff --git a/tests/trio/__init__.py b/tests/trio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/trio/connection.py b/tests/trio/connection.py new file mode 100644 index 000000000..2a7f2aa07 --- /dev/null +++ b/tests/trio/connection.py @@ -0,0 +1,116 @@ +import contextlib + +import trio + +from websockets.trio.connection import Connection + + +class InterceptingConnection(Connection): + """ + Connection subclass that can intercept outgoing packets. + + By interfacing with this connection, we simulate network conditions + affecting what the component being tested receives during a test. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.stream = InterceptingStream(self.stream) + + @contextlib.contextmanager + def delay_frames_sent(self, delay): + """ + Add a delay before sending frames. + + Delays cumulate: they're added before every frame or before EOF. + + """ + assert self.stream.delay_send_all is None + self.stream.delay_send_all = delay + try: + yield + finally: + self.stream.delay_send_all = None + + @contextlib.contextmanager + def delay_eof_sent(self, delay): + """ + Add a delay before sending EOF. + + Delays cumulate: they're added before every frame or before EOF. + + """ + assert self.stream.delay_send_eof is None + self.stream.delay_send_eof = delay + try: + yield + finally: + self.stream.delay_send_eof = None + + @contextlib.contextmanager + def drop_frames_sent(self): + """ + Prevent frames from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.stream.drop_send_all + self.stream.drop_send_all = True + try: + yield + finally: + self.stream.drop_send_all = False + + @contextlib.contextmanager + def drop_eof_sent(self): + """ + Prevent EOF from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.stream.drop_send_eof + self.stream.drop_send_eof = True + try: + yield + finally: + self.stream.drop_send_eof = False + + +class InterceptingStream: + """ + Stream wrapper that intercepts calls to ``send_all()`` and ``send_eof()``. + + This is coupled to the implementation, which relies on these two methods. + + """ + + # We cannot delay EOF with trio's virtual streams because close_hook is + # synchronous. Adopt the same approach as in the other implementations. + + def __init__(self, stream): + self.stream = stream + self.delay_send_all = None + self.delay_send_eof = None + self.drop_send_all = False + self.drop_send_eof = False + + def __getattr__(self, name): + return getattr(self.stream, name) + + async def send_all(self, data): + if self.delay_send_all is not None: + await trio.sleep(self.delay_send_all) + if not self.drop_send_all: + await self.stream.send_all(data) + + async def send_eof(self): + if self.delay_send_eof is not None: + await trio.sleep(self.delay_send_eof) + if not self.drop_send_eof: + await self.stream.send_eof() + + +trio.abc.HalfCloseableStream.register(InterceptingStream) diff --git a/tests/trio/test_connection.py b/tests/trio/test_connection.py new file mode 100644 index 000000000..8c613c700 --- /dev/null +++ b/tests/trio/test_connection.py @@ -0,0 +1,1253 @@ +import contextlib +import logging +import uuid +from unittest.mock import patch + +import trio.testing + +from websockets.asyncio.compatibility import TimeoutError, aiter, anext +from websockets.exceptions import ( + ConcurrencyError, + ConnectionClosedError, + ConnectionClosedOK, +) +from websockets.frames import CloseCode, Frame, Opcode +from websockets.protocol import CLIENT, SERVER, Protocol, State +from websockets.trio.connection import * + +from ..asyncio.utils import alist +from ..protocol import RecordingProtocol +from ..utils import MS, AssertNoLogsMixin +from .connection import InterceptingConnection +from .utils import IsolatedTrioTestCase + + +# Connection implements symmetrical behavior between clients and servers. +# All tests run on the client side and the server side to validate this. + + +class ClientConnectionTests(AssertNoLogsMixin, IsolatedTrioTestCase): + LOCAL = CLIENT + REMOTE = SERVER + + async def asyncSetUp(self): + stream, remote_stream = trio.testing.memory_stream_pair() + protocol = Protocol(self.LOCAL) + remote_protocol = RecordingProtocol(self.REMOTE) + self.connection = Connection( + self.nursery, stream, protocol, close_timeout=2 * MS + ) + self.remote_connection = InterceptingConnection( + self.nursery, remote_stream, remote_protocol + ) + + async def asyncTearDown(self): + await self.remote_connection.close() + await self.connection.close() + + # Test helpers built upon RecordingProtocol and InterceptingConnection. + + async def assertFrameSent(self, frame): + """Check that a single frame was sent.""" + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) + + async def assertFramesSent(self, frames): + """Check that several frames were sent.""" + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), frames) + + async def assertNoFrameSent(self): + """Check that no frame was sent.""" + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) + + @contextlib.asynccontextmanager + async def delay_frames_rcvd(self, delay): + """Delay frames before they're received by the connection.""" + with self.remote_connection.delay_frames_sent(delay): + yield + await trio.testing.wait_all_tasks_blocked() + + @contextlib.asynccontextmanager + async def delay_eof_rcvd(self, delay): + """Delay EOF before it's received by the connection.""" + with self.remote_connection.delay_eof_sent(delay): + yield + await trio.testing.wait_all_tasks_blocked() + + @contextlib.asynccontextmanager + async def drop_frames_rcvd(self): + """Drop frames before they're received by the connection.""" + with self.remote_connection.drop_frames_sent(): + yield + await trio.testing.wait_all_tasks_blocked() + + @contextlib.asynccontextmanager + async def drop_eof_rcvd(self): + """Drop EOF before it's received by the connection.""" + with self.remote_connection.drop_eof_sent(): + yield + await trio.testing.wait_all_tasks_blocked() + + # Test __aenter__ and __aexit__. + + async def test_aenter(self): + """__aenter__ returns the connection itself.""" + async with self.connection as connection: + self.assertIs(connection, self.connection) + + async def test_aexit(self): + """__aexit__ closes the connection with code 1000.""" + async with self.connection: + await self.assertNoFrameSent() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_exit_with_exception(self): + """__exit__ with an exception closes the connection with code 1011.""" + with self.assertRaises(RuntimeError): + async with self.connection: + raise RuntimeError + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xf3")) + + # Test __aiter__. + + async def test_aiter_text(self): + """__aiter__ yields text messages.""" + iterator = aiter(self.connection) + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + await iterator.aclose() + + async def test_aiter_binary(self): + """__aiter__ yields binary messages.""" + iterator = aiter(self.connection) + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") + await iterator.aclose() + + async def test_aiter_mixed(self): + """__aiter__ yields a mix of text and binary messages.""" + iterator = aiter(self.connection) + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") + await iterator.aclose() + + async def test_aiter_connection_closed_ok(self): + """__aiter__ terminates after a normal closure.""" + iterator = aiter(self.connection) + await self.remote_connection.close() + with self.assertRaises(StopAsyncIteration): + await anext(iterator) + await iterator.aclose() + + async def test_aiter_connection_closed_error(self): + """__aiter__ raises ConnectionClosedError after an error.""" + iterator = aiter(self.connection) + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await anext(iterator) + await iterator.aclose() + + # Test recv. + + async def test_recv_text(self): + """recv receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_binary(self): + """recv receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(decode=False), "😀".encode()) + + async def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual(await self.connection.recv(decode=True), "😀") + + async def test_recv_fragmented_text(self): + """recv receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual(await self.connection.recv(), "😀😀") + + async def test_recv_fragmented_binary(self): + """recv receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_connection_closed_ok(self): + """recv raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + await self.connection.recv() + + async def test_recv_connection_closed_error(self): + """recv raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + + async def test_recv_non_utf8_text(self): + """recv receives a non-UTF-8 text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + + async def test_recv_during_recv(self): + """recv raises ConcurrencyError when called concurrently.""" + self.nursery.start_soon(self.connection.recv) + await trio.testing.wait_all_tasks_blocked() + + with self.assertRaises(ConcurrencyError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + await self.remote_connection.send("") + + async def test_recv_during_recv_streaming(self): + """recv raises ConcurrencyError when called concurrently with recv_streaming.""" + self.nursery.start_soon(alist, self.connection.recv_streaming()) + await trio.testing.wait_all_tasks_blocked() + + with self.assertRaises(ConcurrencyError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + await self.remote_connection.send("") + + async def test_recv_cancellation_before_receiving(self): + """recv can be canceled before receiving a message.""" + with trio.move_on_after(MS): + await self.connection.recv() + + # Running recv again receives the next message. + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_cancellation_while_receiving(self): + """recv can be canceled while receiving a fragmented message.""" + gate = trio.Event() + + async def fragments(): + yield "⏳" + await gate.wait() + yield "⌛️" + + async def send_fragments(): + await self.remote_connection.send(fragments()) + + self.nursery.start_soon(send_fragments) + await trio.testing.wait_all_tasks_blocked() + + with trio.move_on_after(MS): + await self.connection.recv() + + gate.set() + + # Running recv again receives the complete message. + self.assertEqual(await self.connection.recv(), "⏳⌛️") + + # Test recv_streaming. + + async def test_recv_streaming_text(self): + """recv_streaming receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀"], + ) + + async def test_recv_streaming_binary(self): + """recv_streaming receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02\xfe\xff"], + ) + + async def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming(decode=False)), + ["😀".encode()], + ) + + async def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual( + await alist(self.connection.recv_streaming(decode=True)), + ["😀"], + ) + + async def test_recv_streaming_fragmented_text(self): + """recv_streaming receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_fragmented_binary(self): + """recv_streaming receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_recv_streaming_connection_closed_ok(self): + """recv_streaming raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_connection_closed_error(self): + """recv_streaming raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_non_utf8_text(self): + """recv_streaming receives a non-UTF-8 text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + await alist(self.connection.recv_streaming()) + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + + async def test_recv_streaming_during_recv(self): + """recv_streaming raises ConcurrencyError when called concurrently with recv.""" + self.nursery.start_soon(self.connection.recv) + await trio.testing.wait_all_tasks_blocked() + + with self.assertRaises(ConcurrencyError) as raised: + await alist(self.connection.recv_streaming()) + self.assertEqual( + str(raised.exception), + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming", + ) + + await self.remote_connection.send("") + + async def test_recv_streaming_during_recv_streaming(self): + """recv_streaming raises ConcurrencyError when called concurrently.""" + self.nursery.start_soon(alist, self.connection.recv_streaming()) + await trio.testing.wait_all_tasks_blocked() + + with self.assertRaises(ConcurrencyError) as raised: + await alist(self.connection.recv_streaming()) + self.assertEqual( + str(raised.exception), + r"cannot call recv_streaming while another coroutine " + r"is already running recv or recv_streaming", + ) + + await self.remote_connection.send("") + + async def test_recv_streaming_cancellation_before_receiving(self): + """recv_streaming can be canceled before receiving a message.""" + with trio.move_on_after(MS): + await alist(self.connection.recv_streaming()) + + # Running recv_streaming again receives the next message. + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_cancellation_while_receiving(self): + """recv_streaming cannot be canceled while receiving a fragmented message.""" + gate = trio.Event() + + async def fragments(): + yield "⏳" + await gate.wait() + yield "⌛️" + + async def send_fragments(): + iterator = fragments() + with self.assertRaises(ConnectionClosedError): + await self.remote_connection.send(iterator) + await iterator.aclose() + + self.nursery.start_soon(send_fragments) + await trio.testing.wait_all_tasks_blocked() + + with trio.move_on_after(MS): + await alist(self.connection.recv_streaming()) + + gate.set() + + # Running recv_streaming again fails. + with self.assertRaises(ConcurrencyError): + await alist(self.connection.recv_streaming()) + + # Test send. + + async def test_send_text(self): + """send sends a text message.""" + await self.connection.send("😀") + self.assertEqual(await self.remote_connection.recv(), "😀") + + async def test_send_binary(self): + """send sends a binary message.""" + await self.connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff") + + async def test_send_binary_from_str(self): + """send sends a binary message from a str.""" + await self.connection.send("😀", text=False) + self.assertEqual(await self.remote_connection.recv(), "😀".encode()) + + async def test_send_text_from_bytes(self): + """send sends a text message from bytes.""" + await self.connection.send("😀".encode(), text=True) + self.assertEqual(await self.remote_connection.recv(), "😀") + + async def test_send_fragmented_text(self): + """send sends a fragmented text message.""" + await self.connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_fragmented_binary(self): + """send sends a fragmented binary message.""" + await self.connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str.""" + await self.connection.send(["😀", "😀"], text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes.""" + await self.connection.send(["😀".encode(), "😀".encode()], text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_async_fragmented_text(self): + """send sends a fragmented text message asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_async_fragmented_binary(self): + """send sends a fragmented binary message asynchronously.""" + + async def fragments(): + yield b"\x01\x02" + yield b"\xfe\xff" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_async_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments(), text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_async_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes asynchronously.""" + + async def fragments(): + yield "😀".encode() + yield "😀".encode() + + await self.connection.send(fragments(), text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_connection_closed_ok(self): + """send raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + await self.connection.send("😀") + + async def test_send_connection_closed_error(self): + """send raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.send("😀") + + async def test_send_during_send_async(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with send_in_progress is removed + # from send() in the case when message is an AsyncIterable. + gate = trio.Event() + + async def fragments(): + yield "⏳" + await gate.wait() + yield "⌛️" + + async def send_fragments(): + await self.connection.send(fragments()) + + self.nursery.start_soon(send_fragments) + await trio.testing.wait_all_tasks_blocked() + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + self.nursery.start_soon(self.connection.send, "✅") + await trio.testing.wait_all_tasks_blocked() + await self.assertNoFrameSent() + + gate.set() + await trio.testing.wait_all_tasks_blocked() + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_empty_iterable(self): + """send does nothing when called with an empty iterable.""" + await self.connection.send([]) + await self.connection.close() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + with self.assertRaises(TypeError): + await self.connection.send(["😀", b"\xfe\xff"]) + + async def test_send_unsupported_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send([None]) + + async def test_send_empty_async_iterable(self): + """send does nothing when called with an empty async iterable.""" + + async def fragments(): + return + yield # pragma: no cover + + await self.connection.send(fragments()) + await self.connection.close() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_async_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + + async def fragments(): + yield "😀" + yield b"\xfe\xff" + + iterator = fragments() + with self.assertRaises(TypeError): + await self.connection.send(iterator) + await iterator.aclose() + + async def test_send_unsupported_async_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + + async def fragments(): + yield None + + iterator = fragments() + with self.assertRaises(TypeError): + await self.connection.send(iterator) + await iterator.aclose() + + async def test_send_dict(self): + """send raises TypeError when called with a dict.""" + with self.assertRaises(TypeError): + await self.connection.send({"type": "object"}) + + async def test_send_unsupported_type(self): + """send raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send(None) + + # Test close. + + async def test_close(self): + """close sends a close frame.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_close_explicit_code_reason(self): + """close sends a close frame with a given code and reason.""" + await self.connection.close(CloseCode.GOING_AWAY, "bye!") + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) + + async def test_close_waits_for_close_frame(self): + """close waits for a close frame (then EOF) before returning.""" + t0 = trio.current_time() + async with self.delay_frames_rcvd(MS): + await self.connection.close() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_waits_for_connection_closed(self): + """close waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + t0 = trio.current_time() + async with self.delay_eof_rcvd(MS): + await self.connection.close() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_no_timeout_waits_for_close_frame(self): + """close without timeout waits for a close frame (then EOF) before returning.""" + self.connection.close_timeout = None + + t0 = trio.current_time() + async with self.delay_frames_rcvd(MS): + await self.connection.close() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_no_timeout_waits_for_connection_closed(self): + """close without timeout waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + self.connection.close_timeout = None + + t0 = trio.current_time() + async with self.delay_eof_rcvd(MS): + await self.connection.close() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_timeout_waiting_for_close_frame(self): + """close times out if no close frame is received.""" + t0 = trio.current_time() + async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + await self.connection.close() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.ABNORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) + + exc = self.connection.protocol.close_exc + self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") + # TODO + # self.assertIsInstance(exc.__cause__, TimeoutError) + + async def test_close_timeout_waiting_for_connection_closed(self): + """close times out if EOF isn't received.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + t0 = trio.current_time() + async with self.drop_eof_rcvd(): + await self.connection.close() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsInstance(exc.__cause__, TimeoutError) + + async def test_close_preserves_queued_messages(self): + """close preserves messages buffered in the assembler.""" + await self.remote_connection.send("😀") + await self.connection.close() + + self.assertEqual(await self.connection.recv(), "😀") + with self.assertRaises(ConnectionClosedOK): + await self.connection.recv() + + async def test_close_idempotency(self): + """close does nothing if the connection is already closed.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + await self.connection.close() + await self.assertNoFrameSent() + + async def test_close_during_recv(self): + """close aborts recv when called concurrently with recv.""" + + async def closer(): + await trio.sleep(MS) + await self.connection.close() + + self.nursery.start_soon(closer) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_during_send(self): + """close fails the connection when called concurrently with send.""" + close_gate = trio.Event() + exit_gate = trio.Event() + + async def closer(): + await close_gate.wait() + await trio.testing.wait_all_tasks_blocked() + await self.connection.close() + exit_gate.set() + + async def fragments(): + yield "⏳" + close_gate.set() + await exit_gate.wait() + yield "⌛️" + + self.nursery.start_soon(closer) + + iterator = fragments() + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.send(iterator) + await iterator.aclose() + + exc = raised.exception + self.assertEqual( + str(exc), + "sent 1011 (internal error) close during fragmented message; " + "no close frame received", + ) + self.assertIsNone(exc.__cause__) + + # Test wait_closed. + + async def test_wait_closed(self): + """wait_closed waits for the connection to close.""" + closed = trio.Event() + + async def closer(): + await self.connection.wait_closed() + closed.set() + + self.nursery.start_soon(closer) + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(closed.is_set()) + + await self.connection.close() + await trio.testing.wait_all_tasks_blocked() + self.assertTrue(closed.is_set()) + + # Test ping. + + @patch("random.getrandbits", return_value=1918987876) + async def test_ping(self, getrandbits): + """ping sends a ping frame with a random payload.""" + await self.connection.ping() + getrandbits.assert_called_once_with(32) + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + + async def test_ping_explicit_text(self): + """ping sends a ping frame with a payload provided as text.""" + await self.connection.ping("ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_ping_explicit_binary(self): + """ping sends a ping frame with a payload provided as binary.""" + await self.connection.ping(b"ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_acknowledge_ping(self): + """ping is acknowledged by a pong with the same payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("this") + await self.remote_connection.pong("this") + with trio.fail_after(MS): + await pong_received.wait() + + async def test_acknowledge_ping_non_matching_pong(self): + """ping isn't acknowledged by a pong with a different payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("this") + await self.remote_connection.pong("that") + with self.assertRaises(trio.TooSlowError): + with trio.fail_after(MS): + await pong_received.wait() + + async def test_acknowledge_previous_ping(self): + """ping is acknowledged by a pong for a later ping.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("this") + await self.connection.ping("that") + await self.remote_connection.pong("that") + with trio.fail_after(MS): + await pong_received.wait() + + async def test_acknowledge_ping_on_close(self): + """ping with ack_on_close is acknowledged when the connection is closed.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received_ack_on_close = await self.connection.ping( + "this", ack_on_close=True + ) + pong_received = await self.connection.ping("that") + await self.connection.close() + with trio.fail_after(MS): + await pong_received_ack_on_close.wait() + with self.assertRaises(trio.TooSlowError): + with trio.fail_after(MS): + await pong_received.wait() + + async def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("idem") + + with self.assertRaises(ConcurrencyError) as raised: + await self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) + + await self.remote_connection.pong("idem") + with trio.fail_after(MS): + await pong_received.wait() + + await self.connection.ping("idem") # doesn't raise an exception + + async def test_ping_unsupported_type(self): + """ping raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.ping([]) + + # Test pong. + + async def test_pong(self): + """pong sends a pong frame.""" + await self.connection.pong() + await self.assertFrameSent(Frame(Opcode.PONG, b"")) + + async def test_pong_explicit_text(self): + """pong sends a pong frame with a payload provided as text.""" + await self.connection.pong("pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + async def test_pong_explicit_binary(self): + """pong sends a pong frame with a payload provided as binary.""" + await self.connection.pong(b"pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + async def test_pong_unsupported_type(self): + """pong raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.pong([]) + + # Test keepalive. + + def keepalive_task_is_running(self): + return any( + task.name == "websockets.trio.connection.Connection.keepalive" + for task in self.nursery.child_tasks + ) + + @patch("random.getrandbits", return_value=1918987876) + async def test_keepalive(self, getrandbits): + """keepalive sends pings at ping_interval and measures latency.""" + self.connection.ping_interval = 3 * MS + self.connection.start_keepalive() + self.assertTrue(self.keepalive_task_is_running()) + self.assertEqual(self.connection.latency, 0) + # 3 ms: keepalive() sends a ping frame. + # 3.x ms: a pong frame is received. + await trio.sleep(4 * MS) + # 4 ms: check that the ping frame was sent. + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + self.assertGreater(self.connection.latency, 0) + self.assertLess(self.connection.latency, MS) + + async def test_disable_keepalive(self): + """keepalive is disabled when ping_interval is None.""" + self.connection.ping_interval = None + self.connection.start_keepalive() + self.assertFalse(self.keepalive_task_is_running()) + + @patch("random.getrandbits", return_value=1918987876) + async def test_keepalive_times_out(self, getrandbits): + """keepalive closes the connection if ping_timeout elapses.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = 2 * MS + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. + # 4.x ms: a pong frame is dropped. + await trio.sleep(5 * MS) + # 6 ms: no pong frame is received; the connection is closed. + await trio.sleep(2 * MS) + # 7 ms: check that the connection is closed. + self.assertEqual(self.connection.state, State.CLOSED) + + @patch("random.getrandbits", return_value=1918987876) + async def test_keepalive_ignores_timeout(self, getrandbits): + """keepalive ignores timeouts if ping_timeout isn't set.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = None + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. + # 4.x ms: a pong frame is dropped. + await trio.sleep(5 * MS) + # 6 ms: no pong frame is received; the connection remains open. + await trio.sleep(2 * MS) + # 7 ms: check that the connection is still open. + self.assertEqual(self.connection.state, State.OPEN) + + async def test_keepalive_terminates_while_sleeping(self): + """keepalive task terminates while waiting to send a ping.""" + self.connection.ping_interval = 3 * MS + self.connection.start_keepalive() + self.assertTrue(self.keepalive_task_is_running()) + await trio.testing.wait_all_tasks_blocked() + await self.connection.close() + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(self.keepalive_task_is_running()) + + async def test_keepalive_terminates_when_sending_ping_fails(self): + """keepalive task terminates when sending a ping fails.""" + self.connection.ping_interval = MS + self.connection.start_keepalive() + self.assertTrue(self.keepalive_task_is_running()) + async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + await self.connection.close() + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(self.keepalive_task_is_running()) + + async def test_keepalive_terminates_while_waiting_for_pong(self): + """keepalive task terminates while waiting to receive a pong.""" + self.connection.ping_interval = MS + self.connection.ping_timeout = 3 * MS + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 1 ms: keepalive() sends a ping frame. + # 1.x ms: a pong frame is dropped. + await trio.sleep(2 * MS) + # 2 ms: close the connection before ping_timeout elapses. + await self.connection.close() + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(self.keepalive_task_is_running()) + + async def test_keepalive_reports_errors(self): + """keepalive reports unexpected errors in logs.""" + self.connection.ping_interval = 2 * MS + self.connection.start_keepalive() + # Inject a fault when waiting to receive a pong. + with self.assertLogs("websockets", logging.ERROR) as logs: + with patch("trio.Event.wait", side_effect=Exception("BOOM")): + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is dropped. + await trio.sleep(3 * MS) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["keepalive ping failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + # Test parameters. + + async def test_close_timeout(self): + """close_timeout parameter configures close timeout.""" + stream, remote_stream = trio.testing.memory_stream_pair() + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + close_timeout=42, + ) + self.assertEqual(connection.close_timeout, 42) + await remote_stream.aclose() + + async def test_max_queue(self): + """max_queue configures high-water mark of frames buffer.""" + stream, remote_stream = trio.testing.memory_stream_pair() + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + max_queue=4, + ) + self.assertEqual(connection.recv_messages.high, 4) + await remote_stream.aclose() + + async def test_max_queue_none(self): + """max_queue disables high-water mark of frames buffer.""" + stream, remote_stream = trio.testing.memory_stream_pair() + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + max_queue=None, + ) + self.assertEqual(connection.recv_messages.high, None) + self.assertEqual(connection.recv_messages.low, None) + await remote_stream.aclose() + + async def test_max_queue_tuple(self): + """max_queue configures high-water and low-water marks of frames buffer.""" + stream, remote_stream = trio.testing.memory_stream_pair() + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + max_queue=(4, 2), + ) + self.assertEqual(connection.recv_messages.high, 4) + self.assertEqual(connection.recv_messages.low, 2) + await remote_stream.aclose() + + # Test attributes. + + async def test_id(self): + """Connection has an id attribute.""" + self.assertIsInstance(self.connection.id, uuid.UUID) + + async def test_logger(self): + """Connection has a logger attribute.""" + self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) + + @contextlib.asynccontextmanager + async def get_server_and_client_streams(self): + listeners = await trio.open_tcp_listeners(0, host="127.0.0.1") + assert len(listeners) == 1 + listener = listeners[0] + client_stream = await trio.testing.open_stream_to_socket_listener(listener) + client_port = client_stream.socket.getsockname()[1] + server_stream = await listener.accept() + server_port = listener.socket.getsockname()[1] + try: + yield client_stream, server_stream, client_port, server_port + finally: + await server_stream.aclose() + await client_stream.aclose() + await listener.aclose() + + async def test_local_address(self): + """Connection provides a local_address attribute.""" + async with self.get_server_and_client_streams() as ( + client_stream, + server_stream, + client_port, + server_port, + ): + stream = {CLIENT: client_stream, SERVER: server_stream}[self.LOCAL] + port = {CLIENT: client_port, SERVER: server_port}[self.LOCAL] + connection = Connection(self.nursery, stream, Protocol(self.LOCAL)) + self.assertEqual(connection.local_address, ("127.0.0.1", port)) + + async def test_remote_address(self): + """Connection provides a remote_address attribute.""" + async with self.get_server_and_client_streams() as ( + client_stream, + server_stream, + client_port, + server_port, + ): + stream = {CLIENT: client_stream, SERVER: server_stream}[self.LOCAL] + remote_port = {CLIENT: server_port, SERVER: client_port}[self.LOCAL] + connection = Connection(self.nursery, stream, Protocol(self.LOCAL)) + self.assertEqual(connection.remote_address, ("127.0.0.1", remote_port)) + + async def test_state(self): + """Connection has a state attribute.""" + self.assertIs(self.connection.state, State.OPEN) + + async def test_request(self): + """Connection has a request attribute.""" + self.assertIsNone(self.connection.request) + + async def test_response(self): + """Connection has a response attribute.""" + self.assertIsNone(self.connection.response) + + async def test_subprotocol(self): + """Connection has a subprotocol attribute.""" + self.assertIsNone(self.connection.subprotocol) + + async def test_close_code(self): + """Connection has a close_code attribute.""" + self.assertIsNone(self.connection.close_code) + + async def test_close_reason(self): + """Connection has a close_reason attribute.""" + self.assertIsNone(self.connection.close_reason) + + # Test reporting of network errors. + + async def test_writing_in_recv_events_fails(self): + """Error when responding to incoming frames is correctly reported.""" + # Inject a fault by shutting down the stream for writing — but not the + # stream for reading because that would terminate the connection. + self.connection.stream.send_stream.close() + + # Receive a ping. Responding with a pong will fail. + await self.remote_connection.ping() + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + self.assertIsInstance(raised.exception.__cause__, trio.ClosedResourceError) + + async def test_writing_in_send_context_fails(self): + """Error when sending outgoing frame is correctly reported.""" + # Inject a fault by shutting down the stream for writing — but not the + # stream for reading because that would terminate the connection. + self.connection.stream.send_stream.close() + + # Sending a pong will fail. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.pong() + + self.assertIsInstance(raised.exception.__cause__, trio.ClosedResourceError) + + # Test safety nets — catching all exceptions in case of bugs. + + # Inject a fault in a random call in recv_events(). + # This test is tightly coupled to the implementation. + @patch("websockets.protocol.Protocol.events_received", side_effect=AssertionError) + async def test_unexpected_failure_in_recv_events(self, events_received): + """Unexpected internal error in recv_events() is correctly reported.""" + # Receive a message to trigger the fault. + await self.remote_connection.send("😀") + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + self.assertIsInstance(raised.exception.__cause__, AssertionError) + + # Inject a fault in a random call in send_context(). + # This test is tightly coupled to the implementation. + @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) + async def test_unexpected_failure_in_send_context(self, send_text): + """Unexpected internal error in send_context() is correctly reported.""" + # Send a message to trigger the fault. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.send("😀") + + self.assertIsInstance(raised.exception.__cause__, AssertionError) + + +class ServerConnectionTests(ClientConnectionTests): + LOCAL = SERVER + REMOTE = CLIENT diff --git a/tests/trio/test_messages.py b/tests/trio/test_messages.py new file mode 100644 index 000000000..838b52bdb --- /dev/null +++ b/tests/trio/test_messages.py @@ -0,0 +1,633 @@ +import unittest +import unittest.mock + +import trio.testing + +from websockets.asyncio.compatibility import aiter, anext +from websockets.exceptions import ConcurrencyError +from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame +from websockets.trio.messages import * + +from ..asyncio.utils import alist +from ..utils import MS +from .utils import IsolatedTrioTestCase + + +class AssemblerTests(IsolatedTrioTestCase): + def setUp(self): + self.pause = unittest.mock.Mock() + self.resume = unittest.mock.Mock() + self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) + + # Test get + + async def test_get_text_message_already_received(self): + """get returns a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_binary_message_already_received(self): + """get returns a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_text_message_not_received_yet(self): + """get returns a text message when it is received.""" + message = None + + async def get_task(): + nonlocal message + message = await self.assembler.get() + + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + self.assertEqual(message, "café") + + async def test_get_binary_message_not_received_yet(self): + """get returns a binary message when it is received.""" + message = None + + async def get_task(): + nonlocal message + message = await self.assembler.get() + + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_BINARY, b"tea")) + + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_already_received(self): + """get reassembles a fragmented a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_already_received(self): + """get reassembles a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" + message = None + + async def get_task(): + nonlocal message + message = await self.assembler.get() + + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" + message = None + + async def get_task(): + nonlocal message + message = await self.assembler.get() + + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" + message = None + + async def get_task(): + nonlocal message + message = await self.assembler.get() + + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" + message = None + + async def get_task(): + nonlocal message + message = await self.assembler.get() + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.assertEqual(message, b"tea") + + async def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") + + async def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + async def test_get_resumes_reading(self): + """get resumes reading when queue goes below the low-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + await self.assembler.get() + self.resume.assert_not_called() + + # queue is at the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + # queue is below the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + async def test_get_does_not_resume_reading(self): + """get does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + await self.assembler.get() + await self.assembler.get() + await self.assembler.get() + + self.resume.assert_not_called() + + async def test_cancel_get_before_first_frame(self): + """get can be canceled safely before reading the first frame.""" + + async def get_task(): + await self.assembler.get() + + with trio.move_on_after(MS) as cancel_scope: + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + self.assertTrue(cancel_scope.cancelled_caught) + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_cancel_get_after_first_frame(self): + """get can be canceled safely after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + async def get_task(): + await self.assembler.get() + + with trio.move_on_after(MS) as cancel_scope: + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + self.assertTrue(cancel_scope.cancelled_caught) + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + message = await self.assembler.get() + self.assertEqual(message, "café") + + # Test get_iter + + async def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_text_message_not_received_yet(self): + """get_iter yields a text message when it is received.""" + fragments = None + + async def get_iter_task(): + nonlocal fragments + fragments = await alist(self.assembler.get_iter()) + + async with trio.open_nursery() as nursery: + nursery.start_soon(get_iter_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_not_received_yet(self): + """get_iter yields a binary message when it is received.""" + fragments = None + + async def get_iter_task(): + nonlocal fragments + fragments = await alist(self.assembler.get_iter()) + + async with trio.open_nursery() as nursery: + nursery.start_soon(get_iter_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_BINARY, b"tea")) + + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_fragmented_text_message_already_received(self): + """get_iter yields a fragmented text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["ca", "f", "é"]) + + async def test_get_iter_fragmented_binary_message_already_received(self): + """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + await iterator.aclose() + + async def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + await iterator.aclose() + + async def test_get_iter_fragmented_text_message_being_received(self): + """get_iter yields a fragmented text message that is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + iterator = aiter(self.assembler.get_iter()) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + await iterator.aclose() + + async def test_get_iter_fragmented_binary_message_being_received(self): + """get_iter yields a fragmented binary message that is partially received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + iterator = aiter(self.assembler.get_iter()) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + await iterator.aclose() + + async def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter(decode=False)) + self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) + + async def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter(decode=True)) + self.assertEqual(fragments, ["t", "e", "a"]) + + async def test_get_iter_resumes_reading(self): + """get_iter resumes reading when queue goes below the low-water mark.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + iterator = aiter(self.assembler.get_iter()) + + # queue is above the low-water mark + await anext(iterator) + self.resume.assert_not_called() + + # queue is at the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + + # queue is below the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + + await iterator.aclose() + + async def test_get_iter_does_not_resume_reading(self): + """get_iter does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + iterator = aiter(self.assembler.get_iter()) + await anext(iterator) + await anext(iterator) + await anext(iterator) + await iterator.aclose() + + self.resume.assert_not_called() + + async def test_cancel_get_iter_before_first_frame(self): + """get_iter can be canceled safely before reading the first frame.""" + + async def get_iter_task(): + await alist(self.assembler.get_iter()) + + with trio.move_on_after(MS) as cancel_scope: + async with trio.open_nursery() as nursery: + nursery.start_soon(get_iter_task) + self.assertTrue(cancel_scope.cancelled_caught) + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_cancel_get_iter_after_first_frame(self): + """get_iter cannot be canceled after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + async def get_iter_task(): + await alist(self.assembler.get_iter()) + + with trio.move_on_after(MS) as cancel_scope: + async with trio.open_nursery() as nursery: + nursery.start_soon(get_iter_task) + self.assertTrue(cancel_scope.cancelled_caught) + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + with self.assertRaises(ConcurrencyError): + await alist(self.assembler.get_iter()) + + # Test put + + async def test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() + + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() + + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() + + async def test_put_does_not_pause_reading(self): + """put does not pause reading when the high-water mark is unset.""" + self.assembler.high = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.pause.assert_not_called() + + # Test termination + + async def test_get_fails_when_interrupted_by_close(self): + """get raises EOFError when close is called.""" + + async def closer(): + self.assembler.close() + + async with trio.open_nursery() as nursery: + nursery.start_soon(closer) + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_when_interrupted_by_close(self): + """get_iter raises EOFError when close is called.""" + + async def closer(): + self.assembler.close() + + async with trio.open_nursery() as nursery: + nursery.start_soon(closer) + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_get_fails_after_close(self): + """get raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_after_close(self): + """get_iter raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOFError on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + async for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) + + async def test_put_fails_after_close(self): + """put raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + async def test_close_is_idempotent(self): + """close can be called multiple times safely.""" + self.assembler.close() + self.assembler.close() + + # Test (non-)concurrency + + async def test_get_fails_when_get_is_running(self): + """get cannot be called concurrently.""" + + async def get_task(): + await self.assembler.get() + + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + nursery.start_soon(get_task) + + async def test_get_fails_when_get_iter_is_running(self): + """get cannot be called concurrently with get_iter.""" + + async def get_task(): + await alist(self.assembler.get_iter()) + + async def get_iter_task(): + await alist(self.assembler.get_iter()) + + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(get_iter_task) + nursery.start_soon(get_task) + + async def test_get_iter_fails_when_get_is_running(self): + """get_iter cannot be called concurrently with get.""" + + async def get_task(): + await alist(self.assembler.get_iter()) + + async def get_iter_task(): + await alist(self.assembler.get_iter()) + + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + nursery.start_soon(get_iter_task) + + async def test_get_iter_fails_when_get_iter_is_running(self): + """get_iter cannot be called concurrently.""" + + async def get_iter_task(): + await alist(self.assembler.get_iter()) + + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(get_iter_task) + nursery.start_soon(get_iter_task) + + # Test setting limits + + async def test_set_high_water_mark(self): + """high sets the high-water and low-water marks.""" + assembler = Assembler(high=10) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 2) + + async def test_set_low_water_mark(self): + """low sets the low-water and high-water marks.""" + assembler = Assembler(low=5) + self.assertEqual(assembler.low, 5) + self.assertEqual(assembler.high, 20) + + async def test_set_high_and_low_water_marks(self): + """high and low set the high-water and low-water marks.""" + assembler = Assembler(high=10, low=5) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 5) + + async def test_unset_high_and_low_water_marks(self): + """High-water and low-water marks are unset.""" + assembler = Assembler() + self.assertEqual(assembler.high, None) + self.assertEqual(assembler.low, None) + + async def test_set_invalid_high_water_mark(self): + """high must be a non-negative integer.""" + with self.assertRaises(ValueError): + Assembler(high=-1) + + async def test_set_invalid_low_water_mark(self): + """low must be higher than high.""" + with self.assertRaises(ValueError): + Assembler(low=10, high=5) diff --git a/tests/trio/utils.py b/tests/trio/utils.py new file mode 100644 index 000000000..4686a74e6 --- /dev/null +++ b/tests/trio/utils.py @@ -0,0 +1,58 @@ +import asyncio +import functools +import sys +import unittest + +import trio.testing + + +if sys.version_info[:2] < (3, 11): # pragma: no cover + from exceptiongroup import ExceptionGroup + + +class IsolatedTrioTestCase(unittest.TestCase): + """ + Wrap test coroutines with :func:`trio.testing.trio_test` automatically. + + Also initializes a nursery for each test and adds :meth:`asyncSetUp` and + :meth:`asyncTearDown`, similar to :class:`unittest.IsolatedAsyncioTestCase`. + + """ + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + for name in unittest.defaultTestLoader.getTestCaseNames(cls): + test = getattr(cls, name) + if getattr(test, "converted_to_trio", False): + return + assert asyncio.iscoroutinefunction(test) + setattr(cls, name, cls.convert_to_trio(test)) + + @staticmethod + def convert_to_trio(test): + @trio.testing.trio_test + @functools.wraps(test) + async def new_test(self, *args, **kwargs): + try: + # Provide a nursery so it's easy to start tasks. + async with trio.open_nursery() as self.nursery: + await self.asyncSetUp() + try: + return await test(self, *args, **kwargs) + finally: + await self.asyncTearDown() + except ExceptionGroup as exc_group: + # Unwrap exceptions like unittest.SkipTest. + if len(exc_group.exceptions) == 1: + raise exc_group.exceptions[0] + else: # pragma: no cover + raise + + new_test.converted_to_trio = True + return new_test + + async def asyncSetUp(self): + pass + + async def asyncTearDown(self): + pass diff --git a/tox.ini b/tox.ini index 9450e9714..7f7d4e101 100644 --- a/tox.ini +++ b/tox.ini @@ -17,6 +17,7 @@ pass_env = deps = py311,py312,py313,coverage,maxi_cov: mitmproxy py311,py312,py313,coverage,maxi_cov: python-socks[asyncio] + trio werkzeug [testenv:coverage] @@ -48,4 +49,5 @@ commands = deps = mypy python-socks + trio werkzeug