From 1272208060bb45489a84d9a178c0d115de4cf1b4 Mon Sep 17 00:00:00 2001 From: WSH032 <614337162@qq.com> Date: Tue, 2 Apr 2024 06:15:43 +0800 Subject: [PATCH 01/11] feat: support `Anyio` for websocket --- pyproject.toml | 5 +- src/fastapi_proxy_lib/core/websocket.py | 348 ++++++++++-------------- 2 files changed, 153 insertions(+), 200 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f3c9bbf..43d46dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,9 +50,11 @@ dynamic = ["version"] dependencies = [ "httpx", - "httpx-ws >= 0.4.2", + "httpx-ws >= 0.5.2", "starlette", "typing_extensions >=4.5.0", + "anyio >= 4", + "exceptiongroup", ] [project.optional-dependencies] @@ -97,7 +99,6 @@ dependencies = [ "pytest-cov == 4.*", "uvicorn[standard] < 1.0.0", # TODO: Once it releases version 1.0.0, we will remove this restriction. "httpx[http2]", # we don't set version here, instead set it in `[project].dependencies`. - "anyio", # we don't set version here, because fastapi has a dependency on it "asgi-lifespan==2.*", "pytest-timeout==2.*", ] diff --git a/src/fastapi_proxy_lib/core/websocket.py b/src/fastapi_proxy_lib/core/websocket.py index 36dacd7..6e36213 100644 --- a/src/fastapi_proxy_lib/core/websocket.py +++ b/src/fastapi_proxy_lib/core/websocket.py @@ -1,34 +1,29 @@ """The websocket proxy lib.""" -import asyncio import logging +import warnings from contextlib import AsyncExitStack +from textwrap import dedent from typing import ( TYPE_CHECKING, Any, List, Literal, - NamedTuple, NoReturn, Optional, Union, ) +import anyio import httpx import httpx_ws import starlette.websockets as starlette_ws -from httpx_ws._api import ( # HACK: 注意,这个是私有模块 - DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, - DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - DEFAULT_MAX_MESSAGE_SIZE_BYTES, - DEFAULT_QUEUE_SIZE, -) +from exceptiongroup import ExceptionGroup from starlette import status as starlette_status -from starlette.exceptions import WebSocketException as StarletteWebSocketException from starlette.responses import Response as StarletteResponse from starlette.responses import StreamingResponse from starlette.types import Scope -from typing_extensions import TypeAlias, override +from typing_extensions import override from wsproto.events import BytesMessage as WsprotoBytesMessage from wsproto.events import TextMessage as WsprotoTextMessage @@ -39,6 +34,40 @@ check_http_version, ) +# XXX: because these variables are private, we have to use try-except to avoid errors +try: + from httpx_ws._api import ( + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + DEFAULT_MAX_MESSAGE_SIZE_BYTES, + DEFAULT_QUEUE_SIZE, + ) +except ImportError: + # ref: https://github.com/frankie567/httpx-ws/blob/b2135792141b71551b022ff0d76542a0263a890c/httpx_ws/_api.py#L31-L34 + DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = ( # pyright: ignore[reportConstantRedefinition] + 20.0 + ) + DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS = ( # pyright: ignore[reportConstantRedefinition] + 20.0 + ) + DEFAULT_MAX_MESSAGE_SIZE_BYTES = ( # pyright: ignore[reportConstantRedefinition] + 65_536 + ) + DEFAULT_QUEUE_SIZE = 512 # pyright: ignore[reportConstantRedefinition] + + msg = dedent( + """\ + Can not import the default httpx_ws arguments, please open an issue on: + https://github.com/WSH032/fastapi-proxy-lib\ + """ + ) + warnings.warn( + msg, + RuntimeWarning, + stacklevel=1, + ) + + __all__ = ( "BaseWebSocketProxy", "ReverseWebSocketProxy", @@ -52,16 +81,9 @@ #################### Data Model #################### -_ClentToServerTaskType: TypeAlias = "asyncio.Task[starlette_ws.WebSocketDisconnect]" -_ServerToClientTaskType: TypeAlias = "asyncio.Task[httpx_ws.WebSocketDisconnect]" - - -class _ClientServerProxyTask(NamedTuple): - """The task group for passing websocket message between client and target server.""" - - client_to_server_task: _ClentToServerTaskType - server_to_client_task: _ServerToClientTaskType - +_WsExceptionGroupType = ExceptionGroup[ + Union[starlette_ws.WebSocketDisconnect, httpx_ws.WebSocketDisconnect, Exception] +] #################### Constant #################### @@ -106,36 +128,36 @@ async def _starlette_ws_receive_bytes_or_str( """Receive bytes or str from starlette WebSocket. - There is already a queue inside to store the received data - - Even if Exception is raised, the {WebSocket} would **not** be closed automatically, you should close it manually + - Even if `AssertionError` is raised, the `WebSocket` would **not** be closed automatically, + you should close it manually, Args: websocket: The starlette WebSocket that has been connected. - "has been connected" measn that you have called "websocket.accept" first. + "has been connected" means that you have called "websocket.accept" first. Raises: starlette.websockets.WebSocketDisconnect: If the WebSocket is disconnected. WebSocketDisconnect.code is the close code. WebSocketDisconnect.reason is the close reason. - **This is normal behavior that you should catch** - StarletteWebSocketException: If receive a invalid message type which is neither bytes nor str. - StarletteWebSocketException.code = starlette_status.WS_1008_POLICY_VIOLATION - StarletteWebSocketException.reason is the close reason. - - RuntimeError: If the WebSocket is not connected. Need to call "accept" first. - If the {websocket} argument you passed in is correct, this error will never be raised, just for asset. + AssertionError: + - If receive a invalid message type which is neither bytes nor str. + - RuntimeError: If the WebSocket is not connected. Need to call "accept" first. + If the `websocket` argument passed in is correct, this error will never be raised, just for assertion. Returns: bytes | str: The received data. """ - # 实现参考: + # Implement reference: # https://github.com/encode/starlette/blob/657e7e7b728e13dc66cc3f77dffd00a42545e171/starlette/websockets.py#L107C1-L115C1 assert ( websocket.application_state == starlette_ws.WebSocketState.CONNECTED ), """WebSocket is not connected. Need to call "accept" first.""" message = await websocket.receive() - # maybe raise WebSocketDisconnect - websocket._raise_on_disconnect(message) # pyright: ignore [reportPrivateUsage] + + if message["type"] == "websocket.disconnect": + raise starlette_ws.WebSocketDisconnect(message["code"], message.get("reason")) # https://asgi.readthedocs.io/en/latest/specs/www.html#receive-receive-event if message.get("bytes") is not None: @@ -143,12 +165,8 @@ async def _starlette_ws_receive_bytes_or_str( elif message.get("text") is not None: return message["text"] else: - # 这种情况应该不会发生,因为这是ASGI标准 + # It should never happen, because of the ASGI spec raise AssertionError("message should have 'bytes' or 'text' key") - raise StarletteWebSocketException( - code=starlette_status.WS_1008_POLICY_VIOLATION, - reason="Invalid message type received (neither bytes nor text).", - ) # 为什么使用这个函数而不是直接使用httpx_ws_AsyncWebSocketSession.receive_text() @@ -159,8 +177,8 @@ async def _httpx_ws_receive_bytes_or_str( """Receive bytes or str from httpx_ws AsyncWebSocketSession . - There is already a queue inside to store the received data - - Even if Exception is raised, the {WebSocket} would **not** be closed automatically, you should close it manually - - except for httpx_ws.WebSocketNetworkError, which will call 'close' automatically + - Even if `AssertionError` or `httpx_ws.WebSocketNetworkError` is raised, the `WebSocket` would **not** be closed automatically, + you should close it manually, Args: websocket: The httpx_ws AsyncWebSocketSession that has been connected. @@ -171,9 +189,8 @@ async def _httpx_ws_receive_bytes_or_str( WebSocketDisconnect.reason is the close reason. - **This is normal behavior that you should catch** httpx_ws.WebSocketNetworkError: A network error occurred. - - httpx_ws.WebSocketInvalidTypeReceived: If receive a invalid message type which is neither bytes nor str. - Usually it will never be raised, just for assert + AssertionError: If receive a invalid message type which is neither bytes nor str. + Usually it will never be raised, just for assertion Returns: bytes | str: The received data. @@ -198,7 +215,7 @@ async def _httpx_ws_receive_bytes_or_str( else: # pragma: no cover # 无法测试这个分支,因为无法发送这种消息,正常来说也不会被执行,所以我们这里记录critical msg = f"Invalid message type received: {type(event)}" logging.critical(msg) - raise httpx_ws.WebSocketInvalidTypeReceived(event) + raise AssertionError(event) async def _httpx_ws_send_bytes_or_str( @@ -207,10 +224,10 @@ async def _httpx_ws_send_bytes_or_str( ) -> None: """Send bytes or str to WebSocket. - - Usually, when Exception is raised, the {WebSocket} is already closed. + - Usually, when Exception is raised, the `WebSocket` is already closed. Args: - websocket: The httpx_ws.AsyncWebSocketSession that has been connected. + websocket: The `httpx_ws.AsyncWebSocketSession` that has been connected. data: The data to send. Raises: @@ -236,7 +253,7 @@ async def _starlette_ws_send_bytes_or_str( ) -> None: """Send bytes or str to WebSocket. - - Even if Exception is raised, the {WebSocket} would **not** be closed automatically, you should close it manually + - Even if Exception is raised, the `WebSocket` would **not** be closed automatically, you should close it manually Args: websocket: The starlette_ws.WebSocket that has been connected. @@ -261,8 +278,8 @@ async def _starlette_ws_send_bytes_or_str( async def _wait_client_then_send_to_server( - *, client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession -) -> starlette_ws.WebSocketDisconnect: + client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession +) -> NoReturn: """Receive data from client, then send to target server. Args: @@ -270,24 +287,22 @@ async def _wait_client_then_send_to_server( server_ws: The websocket which send data to target server. Returns: - If the client_ws sends a shutdown message normally, will return starlette_ws.WebSocketDisconnect. + NoReturn: Never return. Always run forever, except encounter an error, then raise it. Raises: - error for receiving: refer to `_starlette_ws_receive_bytes_or_str` - error for sending: refer to `_httpx_ws_send_bytes_or_str` + error for receiving: refer to `_starlette_ws_receive_bytes_or_str`. + starlette.websockets.WebSocketDisconnect: If the WebSocket is disconnected. + - **This is normal behavior that you should catch**. + error for sending: refer to `_httpx_ws_send_bytes_or_str`. """ while True: - try: - receive = await _starlette_ws_receive_bytes_or_str(client_ws) - except starlette_ws.WebSocketDisconnect as e: - return e - else: - await _httpx_ws_send_bytes_or_str(server_ws, receive) + receive = await _starlette_ws_receive_bytes_or_str(client_ws) + await _httpx_ws_send_bytes_or_str(server_ws, receive) async def _wait_server_then_send_to_client( - *, client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession -) -> httpx_ws.WebSocketDisconnect: + client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession +) -> NoReturn: """Receive data from target server, then send to client. Args: @@ -295,102 +310,84 @@ async def _wait_server_then_send_to_client( server_ws: The websocket which receive data of target server. Returns: - If the server_ws sends a shutdown message normally, will return httpx_ws.WebSocketDisconnect. + NoReturn: Never return. Always run forever, except encounter an error, then raise it. Raises: - error for receiving: refer to `_httpx_ws_receive_bytes_or_str` - error for sending: refer to `_starlette_ws_send_bytes_or_str` + error for receiving: refer to `_httpx_ws_receive_bytes_or_str`. + httpx_ws.WebSocketDisconnect: If the WebSocket is disconnected. + - **This is normal behavior that you should catch** + error for sending: refer to `_starlette_ws_send_bytes_or_str`. """ while True: - try: - receive = await _httpx_ws_receive_bytes_or_str(server_ws) - except httpx_ws.WebSocketDisconnect as e: - return e - else: - await _starlette_ws_send_bytes_or_str(client_ws, receive) + receive = await _httpx_ws_receive_bytes_or_str(server_ws) + await _starlette_ws_send_bytes_or_str(client_ws, receive) async def _close_ws( + excgroup: _WsExceptionGroupType, + /, *, - client_to_server_task: _ClentToServerTaskType, - server_to_client_task: _ServerToClientTaskType, client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession, ) -> None: - """Close ws connection and send status code based on task results. - - - If there is an error, or can't get status code from tasks, then always send a 1011 status code - - Will close ws connection whatever happens. + """Close ws connection and send status code based on `excgroup`. Args: - client_to_server_task: client_to_server_task - server_to_client_task: server_to_client_task + excgroup: The exception group raised when running both client and server proxy tasks. + There should be at most 2 exceptions, one for client, one for server. + If contains `starlette_ws.WebSocketDisconnect`, then will close `server_ws`; + If contains `httpx_ws.WebSocketDisconnect`, then will close `client_ws`. + Else, will close both ws connections with status code `1011`. client_ws: client_ws server_ws: server_ws + """ - try: - # NOTE: 先判断 cancelled ,因为被取消的 task.exception() 会引发异常 - client_error = ( - asyncio.CancelledError - if client_to_server_task.cancelled() - else client_to_server_task.exception() + assert ( + len(excgroup.exceptions) <= 2 + ), "There should be at most 2 exceptions, one for client, one for server." + + client_ws_disc_group = ( + excgroup.subgroup( # pyright: ignore[reportUnknownMemberType] + starlette_ws.WebSocketDisconnect ) - server_error = ( - asyncio.CancelledError - if server_to_client_task.cancelled() - else server_to_client_task.exception() + ) + if client_ws_disc_group: + client_disconnect = client_ws_disc_group.exceptions[0] + # XXX: `isinstance` to make pyright happy + assert isinstance(client_disconnect, starlette_ws.WebSocketDisconnect) + return await server_ws.close(client_disconnect.code, client_disconnect.reason) + + server_ws_disc_group = ( + excgroup.subgroup( # pyright: ignore[reportUnknownMemberType] + httpx_ws.WebSocketDisconnect ) - - if client_error is None: - # clinet端收到正常关闭消息,则关闭server端 - disconnection = client_to_server_task.result() - await server_ws.close(disconnection.code, disconnection.reason) - return - elif server_error is None: - # server端收到正常关闭消息,则关闭client端 - disconnection = server_to_client_task.result() - await client_ws.close(disconnection.code, disconnection.reason) - return - else: - # 如果上述情况都没有发生,意味着至少其中一个任务发生了异常,导致了另一个任务被取消 - # NOTE: 我们不在这个分支调用 `ws.close`,而是留到最后的 finally 来关闭 - client_info = client_ws.client - client_host, client_port = ( - (client_info.host, client_info.port) - if client_info is not None - else (None, None) - ) - # 这里不用dedent是为了更好的性能 - msg = f"""\ -An error occurred in the websocket connection for {client_host}:{client_port}. -client_error: {client_error} -server_error: {server_error}\ + ) + if server_ws_disc_group: + server_disconnect = server_ws_disc_group.exceptions[0] + # XXX: `isinstance` to make pyright happy + assert isinstance(server_disconnect, httpx_ws.WebSocketDisconnect) + return await client_ws.close(server_disconnect.code, server_disconnect.reason) + + # 如果上述情况都没有发生,意味着至少其中一个任务发生了异常,导致了另一个任务被取消 + client_info = client_ws.client + client_host, client_port = ( + (client_info.host, client_info.port) + if client_info is not None + else (None, None) + ) + # 这里不用dedent是为了更好的性能 + msg = f"""\ +An error group occurred in the websocket connection for {client_host}:{client_port}. +error group: {excgroup.exceptions}\ """ - logging.warning(msg) + logging.warning(msg) - except ( - Exception - ) as e: # pragma: no cover # 这个分支是一个保险分支,通常无法执行,所以只进行记录 - logging.error( - f"{e} when close ws connection. client: {client_to_server_task}, server:{server_to_client_task}" - ) - raise - - finally: - # 无论如何,确保关闭两个websocket - # 状态码参考: https://developer.mozilla.org/zh-CN/docs/Web/API/CloseEvent - # https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1 - try: - await client_ws.close(starlette_status.WS_1011_INTERNAL_ERROR) - except Exception: - # 这个分支通常会被触发,因为uvicorn服务器在重复调用close时会引发异常 - pass - try: - await server_ws.close(starlette_status.WS_1011_INTERNAL_ERROR) - except Exception as e: # pragma: no cover - # 这个分支是一个保险分支,通常无法执行,所以只进行记录 - # 不会触发的原因是,负责服务端 ws 连接的 httpx_ws 支持重复调用close而不引发错误 - logging.debug("Unexpected error for debug", exc_info=e) + # Anyway, we should close both ws connections. + # Why we use `1011` code, refer to: + # https://developer.mozilla.org/zh-CN/docs/Web/API/CloseEvent + # https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1 + await client_ws.close(starlette_status.WS_1011_INTERNAL_ERROR) + await server_ws.close(starlette_status.WS_1011_INTERNAL_ERROR) #################### # #################### @@ -523,7 +520,7 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv proxy_ws = await stack.enter_async_context( httpx_ws.aconnect_ws( - # 这个是httpx_ws类型注解的问题,其实是可以使用httpx.URL的 + # XXX: 这个是httpx_ws类型注解的问题,其实是可以使用httpx.URL的 url=target_url, # pyright: ignore [reportArgumentType] client=client, max_message_size_bytes=max_message_size_bytes, @@ -573,71 +570,26 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv # headers=... ) - client_to_server_task = asyncio.create_task( - _wait_client_then_send_to_server( - client_ws=websocket, - server_ws=proxy_ws, - ), - name="client_to_server_task", - ) - server_to_client_task = asyncio.create_task( - _wait_server_then_send_to_client( - client_ws=websocket, - server_ws=proxy_ws, - ), - name="server_to_client_task", - ) - # 保持强引用: https://docs.python.org/zh-cn/3.12/library/asyncio-task.html#creating-tasks - task_group = _ClientServerProxyTask( - client_to_server_task=client_to_server_task, - server_to_client_task=server_to_client_task, - ) - - # NOTE: 考虑这两种情况: - # 1. 如果一个任务在发送阶段退出: - # 这意味着对应发送的ws已经关闭或者出错 - # 那么另一个任务很快就会在接收该ws的时候引发异常而退出 - # 很快,最终两个任务都结束 - # **这时候pending 可能 为空,而done为两个任务** - # 2. 如果一个任务在接收阶段退出: - # 这意味着对应接收的ws已经关闭或者发生出错 - # - 对于另一个任务的发送,可能会在发送的时候引发异常而退出 - # - 可能指的是: wsproto后端的uvicorn发送消息永远不会出错 - # - https://github.com/encode/uvicorn/discussions/2137 - # - 对于另一个任务的接收,可能会等待很久,才能继续进行发送任务而引发异常而退出 - # **这时候pending一般为一个未结束任务** - # - # 因为第二种情况的存在,所以需要用 wait_for 强制让其退出 - # 但考虑到第一种情况,先等它 1s ,看看能否正常退出 try: - _, pending = await asyncio.wait( - task_group, - return_when=asyncio.FIRST_COMPLETED, - ) - for ( - pending_task - ) in pending: # NOTE: pending 一般为一个未结束任务,或者为空 - # 开始取消未结束的任务 - try: - await asyncio.wait_for(pending_task, timeout=1) - except asyncio.TimeoutError: - logging.debug(f"{pending} TimeoutError, it's normal.") - except Exception as e: - # 取消期间可能另一个ws会发生异常,这个是正常情况,且会被 asyncio.wait_for 传播 - logging.debug( - f"{pending} raise error when being canceled, it's normal. error: {e}" - ) - except Exception as e: # pragma: no cover # 这个是保险分支,通常无法执行 - logging.warning( - f"Something wrong, please contact the developer. error: {e}" - ) - raise - finally: - # 无论如何都要关闭两个websocket - # NOTE: 这时候两个任务都已经结束 + async with anyio.create_task_group() as tg: + tg.start_soon( + _wait_client_then_send_to_server, + websocket, + proxy_ws, + name="client_to_server_task", + ) + tg.start_soon( + _wait_server_then_send_to_client, + websocket, + proxy_ws, + name="server_to_client_task", + ) + # XXX: `ExceptionGroup[Any]` is illegal, so we have to ignore the type issue + except ( + ExceptionGroup + ) as excgroup: # pyright: ignore[reportUnknownVariableType] await _close_ws( - client_to_server_task=client_to_server_task, - server_to_client_task=server_to_client_task, + excgroup, # pyright: ignore[reportUnknownArgumentType] client_ws=websocket, server_ws=proxy_ws, ) From 430acd680c654737727812b4a7ff334b16624497 Mon Sep 17 00:00:00 2001 From: WSH032 <614337162@qq.com> Date: Tue, 2 Apr 2024 06:32:06 +0800 Subject: [PATCH 02/11] feat: support `Anyio` totally --- src/fastapi_proxy_lib/core/_tool.py | 4 ++-- src/fastapi_proxy_lib/fastapi/router.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/fastapi_proxy_lib/core/_tool.py b/src/fastapi_proxy_lib/core/_tool.py index 5b3d014..1cfd386 100644 --- a/src/fastapi_proxy_lib/core/_tool.py +++ b/src/fastapi_proxy_lib/core/_tool.py @@ -309,8 +309,8 @@ def return_err_msg_response( err_response_json = ErrRseponseJson(detail=detail) # TODO: 请注意,logging是同步函数,每次会阻塞1ms左右,这可能会导致性能问题 - # 特别是对于写入文件的log,最好把它放到 asyncio.to_thread 里执行 - # https://docs.python.org/zh-cn/3/library/asyncio-task.html#coroutine + # 特别是对于写入文件的log,最好把它放到 `anyio.to_thread.run_sync()` 里执行 + # https://anyio.readthedocs.io/en/stable/threads.html#running-a-function-in-a-worker-thread if logger is not None: # 只要传入了logger,就一定记录日志 diff --git a/src/fastapi_proxy_lib/fastapi/router.py b/src/fastapi_proxy_lib/fastapi/router.py index 076f69d..32ee808 100644 --- a/src/fastapi_proxy_lib/fastapi/router.py +++ b/src/fastapi_proxy_lib/fastapi/router.py @@ -3,7 +3,6 @@ The low-level API for [fastapi_proxy_lib.fastapi.app][]. """ -import asyncio import warnings from contextlib import asynccontextmanager from typing import ( @@ -18,6 +17,7 @@ Union, ) +import anyio from fastapi import APIRouter from starlette.requests import Request from starlette.responses import Response @@ -273,6 +273,8 @@ async def shutdown_clients(*_: Any, **__: Any) -> AsyncIterator[None]: When __aexit__ is called, will close all registered proxy. """ yield - await asyncio.gather(*[proxy.aclose() for proxy in self.registered_proxy]) + async with anyio.create_task_group() as tg: + for proxy in self.registered_proxy: + tg.start_soon(proxy.aclose) return shutdown_clients From ef8da551f4afbc5c92d8163f8493daa239b40e3a Mon Sep 17 00:00:00 2001 From: WSH032 <614337162@qq.com> Date: Wed, 3 Apr 2024 06:46:46 +0800 Subject: [PATCH 03/11] test: using `anyio` instead of `asyncio` in tests And simplify the `UvicornServer` class in test --- pyproject.toml | 6 +- tests/app/echo_ws_app.py | 4 +- tests/app/tool.py | 187 ++++----------------------------------- tests/conftest.py | 9 +- tests/test_ws.py | 33 +++---- 5 files changed, 42 insertions(+), 197 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 43d46dc..b012ebc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,9 +98,11 @@ dependencies = [ "pytest == 7.*", "pytest-cov == 4.*", "uvicorn[standard] < 1.0.0", # TODO: Once it releases version 1.0.0, we will remove this restriction. + "hypercorn[trio] == 0.16.*", "httpx[http2]", # we don't set version here, instead set it in `[project].dependencies`. - "asgi-lifespan==2.*", - "pytest-timeout==2.*", + "asgi-lifespan == 2.*", + "pytest-timeout == 2.*", + "sniffio == 1.3.*", ] [tool.hatch.envs.default.scripts] diff --git a/tests/app/echo_ws_app.py b/tests/app/echo_ws_app.py index e4925e6..132e0bd 100644 --- a/tests/app/echo_ws_app.py +++ b/tests/app/echo_ws_app.py @@ -1,8 +1,8 @@ # ruff: noqa: D100 # pyright: reportUnusedFunction=false -import asyncio +import anyio from fastapi import FastAPI, WebSocket from starlette.websockets import WebSocketDisconnect @@ -76,7 +76,7 @@ async def just_close_with_1001(websocket: WebSocket): test_app_dataclass.request_dict["request"] = websocket await websocket.accept() - await asyncio.sleep(0.3) + await anyio.sleep(0.3) await websocket.close(1001) @app.websocket("/reject_handshake") diff --git a/tests/app/tool.py b/tests/app/tool.py index c740478..77a1909 100644 --- a/tests/app/tool.py +++ b/tests/app/tool.py @@ -1,18 +1,16 @@ # noqa: D100 -import asyncio -import socket +from contextlib import AsyncExitStack from dataclasses import dataclass -from typing import Any, Callable, List, Optional, Type, TypedDict, TypeVar, Union +from typing import Any, TypedDict, Union +import anyio import httpx import uvicorn from fastapi import FastAPI from starlette.requests import Request from starlette.websockets import WebSocket -from typing_extensions import Self, override - -_Decoratable_T = TypeVar("_Decoratable_T", bound=Union[Callable[..., Any], Type[Any]]) +from typing_extensions import Self ServerRecvRequestsTypes = Union[Request, WebSocket] @@ -46,180 +44,32 @@ def get_request(self) -> ServerRecvRequestsTypes: return server_recv_request -def _no_override_uvicorn_server(_method: _Decoratable_T) -> _Decoratable_T: - """Check if the method is already in `uvicorn.Server`.""" - assert not hasattr( - uvicorn.Server, _method.__name__ - ), f"Override method of `uvicorn.Server` cls : {_method.__name__}" - return _method - - -class AeixtTimeoutUndefine: - """Didn't set `contx_exit_timeout` in `aexit()`.""" - - -aexit_timeout_undefine = AeixtTimeoutUndefine() - - -# HACK: 不能继承 AbstractAsyncContextManager[Self] -# 目前有问题,继承 AbstractAsyncContextManager 的话pyright也推测不出来类型 -# 只能依靠 __aenter__ 和 __aexit__ 的类型注解 class UvicornServer(uvicorn.Server): - """subclass of `uvicorn.Server` which can use AsyncContext to launch and shutdown automatically. - - Attributes: - contx_server_task: The task of server. - contx_socket: The socket of server. - - other attributes are same as `uvicorn.Server`: - - config: The config arg that be passed in. - ... - """ + """subclass of `uvicorn.Server` which can use AsyncContext to launch and shutdown automatically.""" - _contx_server_task: Union["asyncio.Task[None]", None] - assert not hasattr(uvicorn.Server, "_contx_server_task") - - _contx_socket: Union[socket.socket, None] - assert not hasattr(uvicorn.Server, "_contx_socket") - - _contx_server_started_event: Union[asyncio.Event, None] - assert not hasattr(uvicorn.Server, "_contx_server_started_event") - - contx_exit_timeout: Union[int, float, None] - assert not hasattr(uvicorn.Server, "contx_exit_timeout") - - @override - def __init__( - self, config: uvicorn.Config, contx_exit_timeout: Union[int, float, None] = None - ) -> None: - """The same as `uvicorn.Server.__init__`.""" - super().__init__(config=config) - self._contx_server_task = None - self._contx_socket = None - self._contx_server_started_event = None - self.contx_exit_timeout = contx_exit_timeout - - @override - async def startup(self, sockets: Optional[List[socket.socket]] = None) -> None: - """The same as `uvicorn.Server.startup`.""" - super_return = await super().startup(sockets=sockets) - self.contx_server_started_event.set() - return super_return - - @_no_override_uvicorn_server - async def aenter(self) -> Self: + async def __aenter__(self) -> Self: """Launch the server.""" - # 在分配资源之前,先检查是否重入 - if self.contx_server_started_event.is_set(): - raise RuntimeError("DO not launch server by __aenter__ again!") - # FIXME: # 这个socket被设计为可被同一进程内的多个server共享,可能会引起潜在问题 - self._contx_socket = self.config.bind_socket() + self._socket = self.config.bind_socket() + self._exit_stack = AsyncExitStack() - self._contx_server_task = asyncio.create_task( - self.serve([self._contx_socket]), name=f"Uvicorn Server Task of {self}" + task_group = await self._exit_stack.enter_async_context( + anyio.create_task_group() + ) + task_group.start_soon( + self.serve, [self._socket], name=f"Uvicorn Server Task of {self}" ) - # 在 uvicorn.Server 的实现中,Server.serve() 内部会调用 Server.startup() 完成启动 - # 被覆盖的 self.startup() 会在完成时调用 self.contx_server_started_event.set() - await self.contx_server_started_event.wait() # 等待服务器确实启动后才返回 - return self - @_no_override_uvicorn_server - async def __aenter__(self) -> Self: - """Launch the server. + return self - The same as `self.aenter()`. - """ - return await self.aenter() - - @_no_override_uvicorn_server - async def aexit( - self, - contx_exit_timeout: Union[ - int, float, None, AeixtTimeoutUndefine - ] = aexit_timeout_undefine, - ) -> None: + async def __aexit__(self, *_: Any, **__: Any) -> None: """Shutdown the server.""" - contx_server_task = self.contx_server_task - contx_socket = self.contx_socket - - if isinstance(contx_exit_timeout, AeixtTimeoutUndefine): - contx_exit_timeout = self.contx_exit_timeout - # 在 uvicorn.Server 的实现中,设置 should_exit 可以使得 server 任务结束 - assert hasattr(self, "should_exit") + assert self.should_exit is False, "The server has already exited." self.should_exit = True - - try: - await asyncio.wait_for(contx_server_task, timeout=contx_exit_timeout) - except asyncio.TimeoutError: - print(f"{contx_server_task.get_name()} timeout!") - finally: - # 其实uvicorn.Server会自动关闭socket,这里是为了保险起见 - contx_socket.close() - - @_no_override_uvicorn_server - async def __aexit__(self, *_: Any, **__: Any) -> None: - """Shutdown the server. - - The same as `self.aexit()`. - """ - return await self.aexit() - - @property - @_no_override_uvicorn_server - def contx_server_started_event(self) -> asyncio.Event: - """The event that indicates the server has started. - - When first call the property, it will instantiate a `asyncio.Event()`to - `self._contx_server_started_event`. - - Warn: This is a internal implementation detail, do not change the event manually. - - please call the property in `self.aenter()` or `self.startup()` **first**. - - **Never** call it outside of an async event loop first: - https://stackoverflow.com/questions/53724665/using-queues-results-in-asyncio-exception-got-future-future-pending-attached - """ - if self._contx_server_started_event is None: - self._contx_server_started_event = asyncio.Event() - - return self._contx_server_started_event - - @property - @_no_override_uvicorn_server - def contx_socket(self) -> socket.socket: - """The socket of server. - - Note: must call `self.__aenter__()` first. - """ - if self._contx_socket is None: - raise RuntimeError("Please call `self.__aenter__()` first.") - else: - return self._contx_socket - - @property - @_no_override_uvicorn_server - def contx_server_task(self) -> "asyncio.Task[None]": - """The task of server. - - Note: must call `self.__aenter__()` first. - """ - if self._contx_server_task is None: - raise RuntimeError("Please call `self.__aenter__()` first.") - else: - return self._contx_server_task - - @property - @_no_override_uvicorn_server - def contx_socket_getname(self) -> Any: - """Utils for calling self.contx_socket.getsockname(). - - Return: - refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families - """ - return self.contx_socket.getsockname() + await self._exit_stack.__aexit__(*_, **__) @property - @_no_override_uvicorn_server def contx_socket_url(self) -> httpx.URL: """If server is tcp socket, return the url of server. @@ -228,7 +78,8 @@ def contx_socket_url(self) -> httpx.URL: config = self.config if config.fd is not None or config.uds is not None: raise RuntimeError("Only support tcp socket.") - host, port = self.contx_socket_getname[:2] + # refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families + host, port = self._socket.getsockname()[:2] return httpx.URL( host=host, port=port, diff --git a/tests/conftest.py b/tests/conftest.py index 0527101..147793e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,6 @@ Coroutine, Literal, Protocol, - Union, ) import pytest @@ -64,7 +63,7 @@ class LifeAppDataclass4Test(AppDataclass4Test): class UvicornServerFixture(Protocol): # noqa: D101 def __call__( # noqa: D102 - self, config: uvicorn.Config, contx_exit_timeout: Union[int, float, None] = None + self, config: uvicorn.Config ) -> Coroutine[None, None, UvicornServer]: ... @@ -199,11 +198,9 @@ async def uvicorn_server_fixture() -> AsyncIterator[UvicornServerFixture]: """ async with AsyncExitStack() as exit_stack: - async def uvicorn_server_fct( - config: uvicorn.Config, contx_exit_timeout: Union[int, float, None] = None - ) -> UvicornServer: + async def uvicorn_server_fct(config: uvicorn.Config) -> UvicornServer: uvicorn_server = await exit_stack.enter_async_context( - UvicornServer(config=config, contx_exit_timeout=contx_exit_timeout) + UvicornServer(config=config) ) return uvicorn_server diff --git a/tests/test_ws.py b/tests/test_ws.py index 2119719..5e33412 100644 --- a/tests/test_ws.py +++ b/tests/test_ws.py @@ -1,11 +1,11 @@ # noqa: D100 -import asyncio from contextlib import AsyncExitStack from multiprocessing import Process, Queue from typing import Any, Dict, Literal, Optional +import anyio import httpx import httpx_ws import pytest @@ -25,7 +25,6 @@ DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 0 -DEFAULT_CONTX_EXIT_TIMEOUT = 5 # WS_BACKENDS_NEED_BE_TESTED = ("websockets", "wsproto") # # FIXME: wsproto 有问题,暂时不测试 @@ -56,14 +55,14 @@ def _subprocess_run_echo_ws_uvicorn_server(queue: "Queue[str]", **kwargs: Any): ) async def run(): - await target_ws_server.aenter() - url = str(target_ws_server.contx_socket_url) - queue.put(url) - queue.close() - while True: # run forever - await asyncio.sleep(0.1) + async with target_ws_server: + url = str(target_ws_server.contx_socket_url) + queue.put(url) + queue.close() + while True: # run forever + await anyio.sleep(0.1) - asyncio.run(run()) + anyio.run(run) def _subprocess_run_httpx_ws( @@ -96,9 +95,9 @@ async def run(): queue.put("done") queue.close() while True: # run forever - await asyncio.sleep(0.1) + await anyio.sleep(0.1) - asyncio.run(run()) + anyio.run(run) class TestReverseWsProxy(AbstractTestProxy): @@ -120,7 +119,6 @@ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverri uvicorn.Config( echo_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=request.param ), - contx_exit_timeout=DEFAULT_CONTX_EXIT_TIMEOUT, ) target_server_base_url = str(target_ws_server.contx_socket_url) @@ -135,7 +133,6 @@ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverri uvicorn.Config( reverse_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=request.param ), - contx_exit_timeout=DEFAULT_CONTX_EXIT_TIMEOUT, ) proxy_server_base_url = str(proxy_ws_server.contx_socket_url) @@ -226,7 +223,7 @@ async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: # 避免从队列中get导致的异步阻塞 while aconnect_ws_subprocess_queue.empty(): - await asyncio.sleep(0.1) + await anyio.sleep(0.1) _ = aconnect_ws_subprocess_queue.get() # 获取到了即代表连接建立成功 # force shutdown client @@ -267,7 +264,7 @@ async def test_target_server_shutdown_abnormally( # 避免从队列中get导致的异步阻塞 while subprocess_queue.empty(): - await asyncio.sleep(0.1) + await anyio.sleep(0.1) target_server_base_url = subprocess_queue.get() client_for_conn_to_target_server = httpx.AsyncClient(proxies=NO_PROXIES) @@ -300,13 +297,11 @@ async def test_target_server_shutdown_abnormally( await ws0.receive() assert exce.value.code == 1011 - loop = asyncio.get_running_loop() - - seconde_ws_recv_start = loop.time() + seconde_ws_recv_start = anyio.current_time() with pytest.raises(httpx_ws.WebSocketDisconnect) as exce: await ws1.receive() assert exce.value.code == 1011 - seconde_ws_recv_end = loop.time() + seconde_ws_recv_end = anyio.current_time() # HACK: 由于收到关闭代码需要40s,目前无法确定是什么原因, # 所以目前会同时测试两个客户端的连接, From b35f2ab0adc4782b4259161644a482c9749e129e Mon Sep 17 00:00:00 2001 From: WSH032 <614337162@qq.com> Date: Wed, 3 Apr 2024 21:11:31 +0800 Subject: [PATCH 04/11] test: add `HypercornServer` and `TestServer` class for testing --- tests/app/tool.py | 141 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 139 insertions(+), 2 deletions(-) diff --git a/tests/app/tool.py b/tests/app/tool.py index 77a1909..75fbf07 100644 --- a/tests/app/tool.py +++ b/tests/app/tool.py @@ -2,12 +2,23 @@ from contextlib import AsyncExitStack from dataclasses import dataclass -from typing import Any, TypedDict, Union +from typing import Any, Literal, TypedDict, Union import anyio import httpx +import sniffio import uvicorn from fastapi import FastAPI +from hypercorn import Config as HyperConfig +from hypercorn.asyncio.run import ( + worker_serve as hyper_aio_serve, # pyright: ignore[reportUnknownVariableType] +) +from hypercorn.trio.run import ( + worker_serve as hyper_trio_serve, # pyright: ignore[reportUnknownVariableType] +) +from hypercorn.utils import ( + wrap_app as hyper_wrap_app, # pyright: ignore[reportUnknownVariableType] +) from starlette.requests import Request from starlette.websockets import WebSocket from typing_extensions import Self @@ -65,7 +76,7 @@ async def __aenter__(self) -> Self: async def __aexit__(self, *_: Any, **__: Any) -> None: """Shutdown the server.""" # 在 uvicorn.Server 的实现中,设置 should_exit 可以使得 server 任务结束 - assert self.should_exit is False, "The server has already exited." + assert not self.should_exit, "The server has already exited." self.should_exit = True await self._exit_stack.__aexit__(*_, **__) @@ -86,3 +97,129 @@ def contx_socket_url(self) -> httpx.URL: scheme="https" if config.is_ssl else "http", path="/", ) + + +class HypercornServer: + """An AsyncContext to launch and shutdown Hypercorn server automatically.""" + + def __init__(self, app: FastAPI, config: HyperConfig): # noqa: D107 + self.config = config + self.app = app + self.should_exit = anyio.Event() + + async def __aenter__(self) -> Self: + """Launch the server.""" + self._sockets = self.config.create_sockets() + self._exit_stack = AsyncExitStack() + + self.current_async_lib = sniffio.current_async_library() + + if self.current_async_lib == "asyncio": + serve_func = hyper_aio_serve # pyright: ignore[reportUnknownVariableType] + elif self.current_async_lib == "trio": + serve_func = hyper_trio_serve # pyright: ignore[reportUnknownVariableType] + else: + raise RuntimeError(f"Unsupported async library {self.current_async_lib!r}") + + async def serve() -> None: + # Implement ref: + # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/asyncio/__init__.py#L12-L46 + # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/trio/__init__.py#L14-L52 + await serve_func( + hyper_wrap_app( + self.app, # pyright: ignore[reportArgumentType] + self.config.wsgi_max_body_size, + mode=None, + ), + self.config, + shutdown_trigger=self.should_exit.wait, + ) + + task_group = await self._exit_stack.enter_async_context( + anyio.create_task_group() + ) + task_group.start_soon(serve, name=f"Hypercorn Server Task of {self}") + return self + + async def __aexit__(self, *_: Any, **__: Any) -> None: + """Shutdown the server.""" + assert not self.should_exit.is_set(), "The server has already exited." + self.should_exit.set() + await self._exit_stack.__aexit__(*_, **__) + + @property + def contx_socket_url(self) -> httpx.URL: + """If server is tcp socket, return the url of server. + + Note: The path of url is explicitly set to "/". + """ + config = self.config + + bind = config.bind[0] + if bind.startswith(("unix:", "fd://")): + raise RuntimeError("Only support tcp socket.") + + # refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families + host, port = config.bind[0].split(":") + port = int(port) + + return httpx.URL( + host=host, + port=port, + scheme="https" if config.ssl_enabled else "http", + path="/", + ) + + +class TestServer: + """An AsyncContext to launch and shutdown Hypercorn or Uvicorn server automatically.""" + + def __init__( + self, + app: FastAPI, + host: str, + port: int, + server_type: Literal["uvicorn", "hypercorn"] = "hypercorn", + ): + """Only support ipv4 address. + + If use uvicorn, it only support asyncio backend. + """ + self.app = app + self.host = host + self.port = port + self.server_type = server_type + + if self.server_type == "hypercorn": + config = HyperConfig() + config.bind = f"{host}:{port}" + + self.config = config + self.server = HypercornServer(app, config) + else: + self.config = uvicorn.Config(app, host=host, port=port) + self.server = UvicornServer(self.config) + + async def __aenter__(self) -> Self: + """Launch the server.""" + if ( + self.server_type == "uvicorn" + and sniffio.current_async_library() != "asyncio" + ): + raise RuntimeError("Uvicorn server does not support trio backend.") + + self._exit_stack = AsyncExitStack() + await self._exit_stack.enter_async_context(self.server) + return self + + async def __aexit__(self, *_: Any, **__: Any) -> None: + """Shutdown the server.""" + await self._exit_stack.__aexit__(*_, **__) + + @property + def contx_socket_url(self) -> httpx.URL: + """If server is tcp socket, return the url of server. + + Note: The path of url is explicitly set to "/". + """ + return self.server.contx_socket_url From 4a7f4b9495cdfef78507f638eaf7b18667c5bee1 Mon Sep 17 00:00:00 2001 From: WSH032 <614337162@qq.com> Date: Thu, 4 Apr 2024 06:40:01 +0800 Subject: [PATCH 05/11] test: add `trio` tests and use `Hypercorn` as the server --- src/fastapi_proxy_lib/core/websocket.py | 2 +- tests/app/tool.py | 84 +++++++++++++++++++------ tests/conftest.py | 47 +++++++++----- tests/test_ws.py | 65 +++++++------------ 4 files changed, 121 insertions(+), 77 deletions(-) diff --git a/src/fastapi_proxy_lib/core/websocket.py b/src/fastapi_proxy_lib/core/websocket.py index 6e36213..f5eb113 100644 --- a/src/fastapi_proxy_lib/core/websocket.py +++ b/src/fastapi_proxy_lib/core/websocket.py @@ -42,7 +42,7 @@ DEFAULT_MAX_MESSAGE_SIZE_BYTES, DEFAULT_QUEUE_SIZE, ) -except ImportError: +except ImportError: # pragma: no cover # ref: https://github.com/frankie567/httpx-ws/blob/b2135792141b71551b022ff0d76542a0263a890c/httpx_ws/_api.py#L31-L34 DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = ( # pyright: ignore[reportConstantRedefinition] 20.0 diff --git a/tests/app/tool.py b/tests/app/tool.py index 75fbf07..dbf65dc 100644 --- a/tests/app/tool.py +++ b/tests/app/tool.py @@ -2,7 +2,7 @@ from contextlib import AsyncExitStack from dataclasses import dataclass -from typing import Any, Literal, TypedDict, Union +from typing import Any, Literal, Optional, TypedDict, Union import anyio import httpx @@ -11,17 +11,20 @@ from fastapi import FastAPI from hypercorn import Config as HyperConfig from hypercorn.asyncio.run import ( - worker_serve as hyper_aio_serve, # pyright: ignore[reportUnknownVariableType] + worker_serve as hyper_aio_worker_serve, # pyright: ignore[reportUnknownVariableType] ) from hypercorn.trio.run import ( - worker_serve as hyper_trio_serve, # pyright: ignore[reportUnknownVariableType] + worker_serve as hyper_trio_worker_serve, # pyright: ignore[reportUnknownVariableType] +) +from hypercorn.utils import ( + repr_socket_addr, # pyright: ignore[reportUnknownVariableType] ) from hypercorn.utils import ( wrap_app as hyper_wrap_app, # pyright: ignore[reportUnknownVariableType] ) from starlette.requests import Request from starlette.websockets import WebSocket -from typing_extensions import Self +from typing_extensions import Self, assert_never ServerRecvRequestsTypes = Union[Request, WebSocket] @@ -55,7 +58,7 @@ def get_request(self) -> ServerRecvRequestsTypes: return server_recv_request -class UvicornServer(uvicorn.Server): +class _UvicornServer(uvicorn.Server): """subclass of `uvicorn.Server` which can use AsyncContext to launch and shutdown automatically.""" async def __aenter__(self) -> Self: @@ -89,7 +92,9 @@ def contx_socket_url(self) -> httpx.URL: config = self.config if config.fd is not None or config.uds is not None: raise RuntimeError("Only support tcp socket.") - # refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families + + # Implement ref: + # https://github.com/encode/uvicorn/blob/a2219eb2ed2bbda4143a0fb18c4b0578881b1ae8/uvicorn/server.py#L201-L220 host, port = self._socket.getsockname()[:2] return httpx.URL( host=host, @@ -99,25 +104,42 @@ def contx_socket_url(self) -> httpx.URL: ) -class HypercornServer: +class _HypercornServer: """An AsyncContext to launch and shutdown Hypercorn server automatically.""" - def __init__(self, app: FastAPI, config: HyperConfig): # noqa: D107 + def __init__(self, app: FastAPI, config: HyperConfig): self.config = config self.app = app self.should_exit = anyio.Event() async def __aenter__(self) -> Self: """Launch the server.""" - self._sockets = self.config.create_sockets() self._exit_stack = AsyncExitStack() self.current_async_lib = sniffio.current_async_library() if self.current_async_lib == "asyncio": - serve_func = hyper_aio_serve # pyright: ignore[reportUnknownVariableType] + serve_func = ( # pyright: ignore[reportUnknownVariableType] + hyper_aio_worker_serve + ) + + # Implement ref: + # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/asyncio/run.py#L89-L90 + self._sockets = self.config.create_sockets() + elif self.current_async_lib == "trio": - serve_func = hyper_trio_serve # pyright: ignore[reportUnknownVariableType] + serve_func = ( # pyright: ignore[reportUnknownVariableType] + hyper_trio_worker_serve + ) + + # Implement ref: + # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/trio/run.py#L51-L56 + self._sockets = self.config.create_sockets() + for sock in self._sockets.secure_sockets: + sock.listen(self.config.backlog) + for sock in self._sockets.insecure_sockets: + sock.listen(self.config.backlog) + else: raise RuntimeError(f"Unsupported async library {self.current_async_lib!r}") @@ -133,6 +155,7 @@ async def serve() -> None: ), self.config, shutdown_trigger=self.should_exit.wait, + sockets=self._sockets, ) task_group = await self._exit_stack.enter_async_context( @@ -154,13 +177,32 @@ def contx_socket_url(self) -> httpx.URL: Note: The path of url is explicitly set to "/". """ config = self.config + sockets = self._sockets + + # Implement ref: + # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/asyncio/run.py#L112-L149 + # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/trio/run.py#L61-L82 + + # We only run on one socket each time, + # so we raise `RuntimeError` to avoid other unknown errors during testing. + if sockets.insecure_sockets: + if len(sockets.insecure_sockets) > 1: + raise RuntimeError("Hypercorn test: Multiple insecure_sockets found.") + socket = sockets.insecure_sockets[0] + elif sockets.secure_sockets: + if len(sockets.secure_sockets) > 1: + raise RuntimeError("Hypercorn test: secure_sockets sockets found.") + socket = sockets.secure_sockets[0] + else: + raise RuntimeError("Hypercorn test: No socket found.") - bind = config.bind[0] + bind = repr_socket_addr(socket.family, socket.getsockname()) if bind.startswith(("unix:", "fd://")): raise RuntimeError("Only support tcp socket.") - # refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families - host, port = config.bind[0].split(":") + # Implement ref: + # https://docs.python.org/zh-cn/3/library/socket.html#socket-families + host, port = bind.split(":") port = int(port) return httpx.URL( @@ -179,12 +221,16 @@ def __init__( app: FastAPI, host: str, port: int, - server_type: Literal["uvicorn", "hypercorn"] = "hypercorn", + server_type: Optional[Literal["uvicorn", "hypercorn"]] = None, ): """Only support ipv4 address. If use uvicorn, it only support asyncio backend. + + If `host` == 0, then use random port. """ + server_type = server_type if server_type is not None else "hypercorn" + self.app = app self.host = host self.port = port @@ -195,10 +241,12 @@ def __init__( config.bind = f"{host}:{port}" self.config = config - self.server = HypercornServer(app, config) - else: + self.server = _HypercornServer(app, config) + elif self.server_type == "uvicorn": self.config = uvicorn.Config(app, host=host, port=port) - self.server = UvicornServer(self.config) + self.server = _UvicornServer(self.config) + else: + assert_never(self.server_type) async def __aenter__(self) -> Self: """Launch the server.""" diff --git a/tests/conftest.py b/tests/conftest.py index 147793e..39fbb30 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,12 +15,13 @@ Callable, Coroutine, Literal, + Optional, Protocol, ) import pytest -import uvicorn from asgi_lifespan import LifespanManager +from fastapi import FastAPI from fastapi_proxy_lib.fastapi.app import ( forward_http_app, reverse_http_app, @@ -30,7 +31,7 @@ from .app.echo_http_app import get_app as get_http_test_app from .app.echo_ws_app import get_app as get_ws_test_app -from .app.tool import AppDataclass4Test, UvicornServer +from .app.tool import AppDataclass4Test, TestServer # ASGI types. # Copied from: https://github.com/florimondmanca/asgi-lifespan/blob/fbb0f440337314be97acaae1a3c0c7a2ec8298dd/src/asgi_lifespan/_types.py @@ -61,17 +62,28 @@ class LifeAppDataclass4Test(AppDataclass4Test): """The lifespan of app will be managed automatically by pytest.""" -class UvicornServerFixture(Protocol): # noqa: D101 +class TestServerFixture(Protocol): # noqa: D101 def __call__( # noqa: D102 - self, config: uvicorn.Config - ) -> Coroutine[None, None, UvicornServer]: ... + self, + app: FastAPI, + host: str, + port: int, + server_type: Optional[Literal["uvicorn", "hypercorn"]] = None, + ) -> Coroutine[None, None, TestServer]: ... # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on -@pytest.fixture() -def anyio_backend() -> Literal["asyncio"]: +@pytest.fixture( + params=[ + pytest.param(("asyncio", {"use_uvloop": False}), id="asyncio"), + pytest.param( + ("trio", {"restrict_keyboard_interrupt_to_checkpoints": True}), id="trio" + ), + ], +) +def anyio_backend(request: pytest.FixtureRequest): """Specify the async backend for `pytest.mark.anyio`.""" - return "asyncio" + return request.param @pytest.fixture() @@ -191,17 +203,22 @@ def reverse_ws_app_fct( @pytest.fixture() -async def uvicorn_server_fixture() -> AsyncIterator[UvicornServerFixture]: - """Fixture for UvicornServer. +async def test_server_fixture() -> AsyncIterator[TestServerFixture]: + """Fixture for TestServer. Will launch and shutdown automatically. """ async with AsyncExitStack() as exit_stack: - async def uvicorn_server_fct(config: uvicorn.Config) -> UvicornServer: - uvicorn_server = await exit_stack.enter_async_context( - UvicornServer(config=config) + async def test_server_fct( + app: FastAPI, + host: str, + port: int, + server_type: Optional[Literal["uvicorn", "hypercorn"]] = None, + ) -> TestServer: + test_server = await exit_stack.enter_async_context( + TestServer(app=app, host=host, port=port, server_type=server_type) ) - return uvicorn_server + return test_server - yield uvicorn_server_fct + yield test_server_fct diff --git a/tests/test_ws.py b/tests/test_ws.py index 5e33412..3338227 100644 --- a/tests/test_ws.py +++ b/tests/test_ws.py @@ -3,55 +3,43 @@ from contextlib import AsyncExitStack from multiprocessing import Process, Queue -from typing import Any, Dict, Literal, Optional +from typing import Any, Dict, Optional import anyio import httpx import httpx_ws import pytest -import uvicorn from fastapi_proxy_lib.fastapi.app import reverse_ws_app as get_reverse_ws_app from httpx_ws import aconnect_ws from starlette import websockets as starlette_websockets_module from typing_extensions import override from .app.echo_ws_app import get_app as get_ws_test_app -from .app.tool import UvicornServer -from .conftest import UvicornServerFixture +from .app.tool import TestServer +from .conftest import TestServerFixture from .tool import ( AbstractTestProxy, Tool4TestFixture, ) DEFAULT_HOST = "127.0.0.1" -DEFAULT_PORT = 0 - -# WS_BACKENDS_NEED_BE_TESTED = ("websockets", "wsproto") -# # FIXME: wsproto 有问题,暂时不测试 -# # ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接。 -# # https://github.com/encode/uvicorn/discussions/2105 -WS_BACKENDS_NEED_BE_TESTED = ("websockets",) +DEFAULT_PORT = 0 # random port # https://www.python-httpx.org/advanced/#http-proxying NO_PROXIES: Dict[Any, Any] = {"all://": None} -def _subprocess_run_echo_ws_uvicorn_server(queue: "Queue[str]", **kwargs: Any): +def _subprocess_run_echo_ws_server(queue: "Queue[str]"): """Run echo ws app in subprocess. Args: queue: The queue for subprocess to put the url of echo ws app. After the server is started, the url will be put into the queue. - **kwargs: The kwargs for `uvicorn.Config` """ - default_kwargs = { - "app": get_ws_test_app().app, - "port": DEFAULT_PORT, - "host": DEFAULT_HOST, - } - default_kwargs.update(kwargs) - target_ws_server = UvicornServer( - uvicorn.Config(**default_kwargs), # pyright: ignore[reportArgumentType] + target_ws_server = TestServer( + app=get_ws_test_app().app, + host=DEFAULT_HOST, + port=DEFAULT_PORT, ) async def run(): @@ -104,21 +92,18 @@ class TestReverseWsProxy(AbstractTestProxy): """For testing reverse websocket proxy.""" @override - @pytest.fixture(params=WS_BACKENDS_NEED_BE_TESTED) + @pytest.fixture() async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverride] self, - uvicorn_server_fixture: UvicornServerFixture, - request: pytest.FixtureRequest, + test_server_fixture: TestServerFixture, ) -> Tool4TestFixture: """目标服务器请参考`tests.app.echo_ws_app.get_app`.""" echo_ws_test_model = get_ws_test_app() echo_ws_app = echo_ws_test_model.app echo_ws_get_request = echo_ws_test_model.get_request - target_ws_server = await uvicorn_server_fixture( - uvicorn.Config( - echo_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=request.param - ), + target_ws_server = await test_server_fixture( + app=echo_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST ) target_server_base_url = str(target_ws_server.contx_socket_url) @@ -129,10 +114,8 @@ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverri client=client_for_conn_to_target_server, base_url=target_server_base_url ) - proxy_ws_server = await uvicorn_server_fixture( - uvicorn.Config( - reverse_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=request.param - ), + proxy_ws_server = await test_server_fixture( + app=reverse_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST ) proxy_server_base_url = str(proxy_ws_server.contx_socket_url) @@ -197,7 +180,7 @@ async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: client_for_conn_to_proxy_server, ) as ws: pass - # uvicorn 服务器在未调用`websocket.accept()`之前调用了`websocket.close()`,会发生403 + # Starlette 在未调用`websocket.accept()`之前调用了`websocket.close()`,会发生403 assert exce.value.response.status_code == 403 ########## 客户端突然关闭时,服务器应该收到1011 ########## @@ -245,10 +228,7 @@ async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: # FIXME: 调查为什么收到关闭代码需要40s @pytest.mark.timeout(60) @pytest.mark.anyio() - @pytest.mark.parametrize("ws_backend", WS_BACKENDS_NEED_BE_TESTED) - async def test_target_server_shutdown_abnormally( - self, ws_backend: Literal["websockets", "wsproto"] - ) -> None: + async def test_target_server_shutdown_abnormally(self) -> None: """测试因为目标服务器突然断连导致的,ws桥接异常关闭. 需要在 60s 内向客户端发送 1011 关闭代码. @@ -256,9 +236,8 @@ async def test_target_server_shutdown_abnormally( subprocess_queue: "Queue[str]" = Queue() target_ws_server_subprocess = Process( - target=_subprocess_run_echo_ws_uvicorn_server, + target=_subprocess_run_echo_ws_server, args=(subprocess_queue,), - kwargs={"port": DEFAULT_PORT, "host": DEFAULT_HOST, "ws": ws_backend}, ) target_ws_server_subprocess.start() @@ -273,10 +252,10 @@ async def test_target_server_shutdown_abnormally( client=client_for_conn_to_target_server, base_url=target_server_base_url ) - async with UvicornServer( - uvicorn.Config( - reverse_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=ws_backend - ) + async with TestServer( + app=reverse_ws_app, + port=DEFAULT_PORT, + host=DEFAULT_HOST, ) as proxy_ws_server: proxy_server_base_url = str(proxy_ws_server.contx_socket_url) From 8f0ba1fe7aac110d77731e3d591e905321f6f32b Mon Sep 17 00:00:00 2001 From: WSH032 <614337162@qq.com> Date: Thu, 4 Apr 2024 08:13:25 +0800 Subject: [PATCH 06/11] test: fix the warnings in the tests by not using deprecated APIs from `httpx` --- tests/app/tool.py | 2 +- tests/conftest.py | 22 ++++++++--------- tests/test_core_lib.py | 5 +++- tests/test_http.py | 43 +++++++++++++++++++++++++-------- tests/test_ws.py | 55 ++++++++++++++++-------------------------- 5 files changed, 70 insertions(+), 57 deletions(-) diff --git a/tests/app/tool.py b/tests/app/tool.py index dbf65dc..9236e91 100644 --- a/tests/app/tool.py +++ b/tests/app/tool.py @@ -213,7 +213,7 @@ def contx_socket_url(self) -> httpx.URL: ) -class TestServer: +class AutoServer: """An AsyncContext to launch and shutdown Hypercorn or Uvicorn server automatically.""" def __init__( diff --git a/tests/conftest.py b/tests/conftest.py index 39fbb30..e60ec07 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,7 +31,7 @@ from .app.echo_http_app import get_app as get_http_test_app from .app.echo_ws_app import get_app as get_ws_test_app -from .app.tool import AppDataclass4Test, TestServer +from .app.tool import AppDataclass4Test, AutoServer # ASGI types. # Copied from: https://github.com/florimondmanca/asgi-lifespan/blob/fbb0f440337314be97acaae1a3c0c7a2ec8298dd/src/asgi_lifespan/_types.py @@ -62,14 +62,14 @@ class LifeAppDataclass4Test(AppDataclass4Test): """The lifespan of app will be managed automatically by pytest.""" -class TestServerFixture(Protocol): # noqa: D101 +class AutoServerFixture(Protocol): # noqa: D101 def __call__( # noqa: D102 self, app: FastAPI, host: str, port: int, server_type: Optional[Literal["uvicorn", "hypercorn"]] = None, - ) -> Coroutine[None, None, TestServer]: ... + ) -> Coroutine[None, None, AutoServer]: ... # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on @@ -203,22 +203,22 @@ def reverse_ws_app_fct( @pytest.fixture() -async def test_server_fixture() -> AsyncIterator[TestServerFixture]: - """Fixture for TestServer. +async def auto_server_fixture() -> AsyncIterator[AutoServerFixture]: + """Fixture for AutoServer. Will launch and shutdown automatically. """ async with AsyncExitStack() as exit_stack: - async def test_server_fct( + async def auto_server_fct( app: FastAPI, host: str, port: int, server_type: Optional[Literal["uvicorn", "hypercorn"]] = None, - ) -> TestServer: - test_server = await exit_stack.enter_async_context( - TestServer(app=app, host=host, port=port, server_type=server_type) + ) -> AutoServer: + auto_server = await exit_stack.enter_async_context( + AutoServer(app=app, host=host, port=port, server_type=server_type) ) - return test_server + return auto_server - yield test_server_fct + yield auto_server_fct diff --git a/tests/test_core_lib.py b/tests/test_core_lib.py index da52e58..b214b2e 100644 --- a/tests/test_core_lib.py +++ b/tests/test_core_lib.py @@ -70,7 +70,10 @@ async def _() -> JSONResponse: # } # } - client = httpx.AsyncClient(app=app, base_url="http://www.example.com") + client = httpx.AsyncClient( + transport=httpx.ASGITransport(app), # pyright: ignore[reportArgumentType] + base_url="http://www.example.com", + ) resp = await client.get("http://www.example.com/exception") assert resp.status_code == 0 assert resp.json()["detail"] == test_err_msg diff --git a/tests/test_http.py b/tests/test_http.py index e2c649c..71b56ac 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -32,7 +32,10 @@ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverri ) -> Tool4TestFixture: """目标服务器请参考`tests.app.echo_http_app.get_app`.""" client_for_conn_to_target_server = httpx.AsyncClient( - app=echo_http_test_model.app, base_url=DEFAULT_TARGET_SERVER_BASE_URL + transport=httpx.ASGITransport( + echo_http_test_model.app # pyright: ignore[reportArgumentType] + ), + base_url=DEFAULT_TARGET_SERVER_BASE_URL, ) reverse_http_app = await reverse_http_app_fct( @@ -41,7 +44,10 @@ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverri ) client_for_conn_to_proxy_server = httpx.AsyncClient( - app=reverse_http_app, base_url=DEFAULT_PROXY_SERVER_BASE_URL + transport=httpx.ASGITransport( + reverse_http_app # pyright: ignore[reportArgumentType] + ), + base_url=DEFAULT_PROXY_SERVER_BASE_URL, ) get_request = echo_http_test_model.get_request @@ -198,7 +204,10 @@ async def test_bad_url_request( ) client_for_conn_to_proxy_server = httpx.AsyncClient( - app=reverse_http_app, base_url=DEFAULT_PROXY_SERVER_BASE_URL + transport=httpx.ASGITransport( + reverse_http_app # pyright: ignore[reportArgumentType] + ), + base_url=DEFAULT_PROXY_SERVER_BASE_URL, ) r = await client_for_conn_to_proxy_server.get(DEFAULT_PROXY_SERVER_BASE_URL) @@ -233,9 +242,9 @@ async def test_cookie_leakage( assert not client_for_conn_to_proxy_server.cookies # check if cookie is not leaked + client_for_conn_to_proxy_server.cookies.set("a", "b") r = await client_for_conn_to_proxy_server.get( - proxy_server_base_url + "get/cookies", - cookies={"a": "b"}, + proxy_server_base_url + "get/cookies" ) assert "foo" not in r.json() # not leaked assert r.json()["a"] == "b" # send cookies normally @@ -252,7 +261,10 @@ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverri ) -> Tool4TestFixture: """目标服务器请参考`tests.app.echo_http_app.get_app`.""" client_for_conn_to_target_server = httpx.AsyncClient( - app=echo_http_test_model.app, base_url=DEFAULT_TARGET_SERVER_BASE_URL + transport=httpx.ASGITransport( + echo_http_test_model.app # pyright: ignore[reportArgumentType] + ), + base_url=DEFAULT_TARGET_SERVER_BASE_URL, ) forward_http_app = await forward_http_app_fct( @@ -260,7 +272,10 @@ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverri ) client_for_conn_to_proxy_server = httpx.AsyncClient( - app=forward_http_app, base_url=DEFAULT_PROXY_SERVER_BASE_URL + transport=httpx.ASGITransport( + forward_http_app # pyright: ignore[reportArgumentType] + ), + base_url=DEFAULT_PROXY_SERVER_BASE_URL, ) get_request = echo_http_test_model.get_request @@ -310,7 +325,10 @@ async def test_bad_url_request( ) client_for_conn_to_proxy_server = httpx.AsyncClient( - app=forward_http_app, base_url=DEFAULT_PROXY_SERVER_BASE_URL + transport=httpx.ASGITransport( + forward_http_app # pyright: ignore[reportArgumentType] + ), + base_url=DEFAULT_PROXY_SERVER_BASE_URL, ) # 错误的无法发出请求的URL @@ -356,7 +374,10 @@ async def connect_error_mock_handler( ) client_for_conn_to_proxy_server = httpx.AsyncClient( - app=forward_http_app, base_url=DEFAULT_PROXY_SERVER_BASE_URL + transport=httpx.ASGITransport( + forward_http_app # pyright: ignore[reportArgumentType] + ), + base_url=DEFAULT_PROXY_SERVER_BASE_URL, ) r = await client_for_conn_to_proxy_server.get( @@ -385,7 +406,9 @@ async def test_denial_http2( ) client_for_conn_to_proxy_server = httpx.AsyncClient( - app=forward_http_app, + transport=httpx.ASGITransport( + forward_http_app + ), # pyright: ignore[reportArgumentType] base_url=proxy_server_base_url, http2=True, http1=False, diff --git a/tests/test_ws.py b/tests/test_ws.py index 3338227..8880cc6 100644 --- a/tests/test_ws.py +++ b/tests/test_ws.py @@ -3,7 +3,7 @@ from contextlib import AsyncExitStack from multiprocessing import Process, Queue -from typing import Any, Dict, Optional +from typing import Any, Dict import anyio import httpx @@ -15,8 +15,8 @@ from typing_extensions import override from .app.echo_ws_app import get_app as get_ws_test_app -from .app.tool import TestServer -from .conftest import TestServerFixture +from .app.tool import AutoServer +from .conftest import AutoServerFixture from .tool import ( AbstractTestProxy, Tool4TestFixture, @@ -25,7 +25,8 @@ DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 0 # random port -# https://www.python-httpx.org/advanced/#http-proxying +# https://www.python-httpx.org/advanced/proxies/ +# NOTE: Foce to connect directly, avoid using system proxies NO_PROXIES: Dict[Any, Any] = {"all://": None} @@ -36,7 +37,7 @@ def _subprocess_run_echo_ws_server(queue: "Queue[str]"): queue: The queue for subprocess to put the url of echo ws app. After the server is started, the url will be put into the queue. """ - target_ws_server = TestServer( + target_ws_server = AutoServer( app=get_ws_test_app().app, host=DEFAULT_HOST, port=DEFAULT_PORT, @@ -55,29 +56,22 @@ async def run(): def _subprocess_run_httpx_ws( queue: "Queue[str]", - kwargs_async_client: Optional[Dict[str, Any]] = None, - kwargs_aconnect_ws: Optional[Dict[str, Any]] = None, + aconnect_ws_url: str, ): """Run aconnect_ws in subprocess. Args: queue: The queue for subprocess to put something for flag of ws connection established. - kwargs_async_client: The kwargs for `httpx.AsyncClient` - kwargs_aconnect_ws: The kwargs for `httpx_ws.aconnect_ws` + aconnect_ws_url: The websocket url for aconnect_ws. """ - kwargs_async_client = kwargs_async_client or {} - kwargs_aconnect_ws = kwargs_aconnect_ws or {} - - kwargs_async_client.pop("proxies", None) - kwargs_aconnect_ws.pop("client", None) async def run(): _exit_stack = AsyncExitStack() - _temp_client = httpx.AsyncClient(proxies=NO_PROXIES, **kwargs_async_client) + _temp_client = httpx.AsyncClient(mounts=NO_PROXIES) _ = await _exit_stack.enter_async_context( aconnect_ws( client=_temp_client, - **kwargs_aconnect_ws, + url=aconnect_ws_url, ) ) queue.put("done") @@ -95,32 +89,32 @@ class TestReverseWsProxy(AbstractTestProxy): @pytest.fixture() async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverride] self, - test_server_fixture: TestServerFixture, + auto_server_fixture: AutoServerFixture, ) -> Tool4TestFixture: """目标服务器请参考`tests.app.echo_ws_app.get_app`.""" echo_ws_test_model = get_ws_test_app() echo_ws_app = echo_ws_test_model.app echo_ws_get_request = echo_ws_test_model.get_request - target_ws_server = await test_server_fixture( + target_ws_server = await auto_server_fixture( app=echo_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST ) target_server_base_url = str(target_ws_server.contx_socket_url) - client_for_conn_to_target_server = httpx.AsyncClient(proxies=NO_PROXIES) + client_for_conn_to_target_server = httpx.AsyncClient(mounts=NO_PROXIES) reverse_ws_app = get_reverse_ws_app( client=client_for_conn_to_target_server, base_url=target_server_base_url ) - proxy_ws_server = await test_server_fixture( + proxy_ws_server = await auto_server_fixture( app=reverse_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST ) proxy_server_base_url = str(proxy_ws_server.contx_socket_url) - client_for_conn_to_proxy_server = httpx.AsyncClient(proxies=NO_PROXIES) + client_for_conn_to_proxy_server = httpx.AsyncClient(mounts=NO_PROXIES) return Tool4TestFixture( client_for_conn_to_target_server=client_for_conn_to_target_server, @@ -189,18 +183,11 @@ async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: # 是因为这里已经有现成的target server,放在这里测试可以节省启动服务器时间 aconnect_ws_subprocess_queue: "Queue[str]" = Queue() - - kwargs_async_client = {"proxies": NO_PROXIES} - kwargs_aconnect_ws = {"url": proxy_server_base_url + "do_nothing"} - kwargs = { - "kwargs_async_client": kwargs_async_client, - "kwargs_aconnect_ws": kwargs_aconnect_ws, - } + aconnect_ws_url = proxy_server_base_url + "do_nothing" aconnect_ws_subprocess = Process( target=_subprocess_run_httpx_ws, - args=(aconnect_ws_subprocess_queue,), - kwargs=kwargs, + args=(aconnect_ws_subprocess_queue, aconnect_ws_url), ) aconnect_ws_subprocess.start() @@ -246,13 +233,13 @@ async def test_target_server_shutdown_abnormally(self) -> None: await anyio.sleep(0.1) target_server_base_url = subprocess_queue.get() - client_for_conn_to_target_server = httpx.AsyncClient(proxies=NO_PROXIES) + client_for_conn_to_target_server = httpx.AsyncClient(mounts=NO_PROXIES) reverse_ws_app = get_reverse_ws_app( client=client_for_conn_to_target_server, base_url=target_server_base_url ) - async with TestServer( + async with AutoServer( app=reverse_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, @@ -261,10 +248,10 @@ async def test_target_server_shutdown_abnormally(self) -> None: async with aconnect_ws( proxy_server_base_url + "do_nothing", - httpx.AsyncClient(proxies=NO_PROXIES), + httpx.AsyncClient(mounts=NO_PROXIES), ) as ws0, aconnect_ws( proxy_server_base_url + "do_nothing", - httpx.AsyncClient(proxies=NO_PROXIES), + httpx.AsyncClient(mounts=NO_PROXIES), ) as ws1: # force shutdown target server target_ws_server_subprocess.terminate() From 43466157e56af6a47011c9facf420c7aab0a0dd6 Mon Sep 17 00:00:00 2001 From: WSH032 <614337162@qq.com> Date: Thu, 4 Apr 2024 08:18:32 +0800 Subject: [PATCH 07/11] docs: remove deprecated httpx APIs in example --- docs/Usage/FastAPI-Helper.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Usage/FastAPI-Helper.md b/docs/Usage/FastAPI-Helper.md index 45f1ebe..ccdee99 100644 --- a/docs/Usage/FastAPI-Helper.md +++ b/docs/Usage/FastAPI-Helper.md @@ -29,7 +29,7 @@ app = reverse_http_app(client=client, base_url=base_url) ``` 1. You can pass `httpx.AsyncClient` instance: - - if you want to customize the arguments, e.g. `httpx.AsyncClient(proxies={})` + - if you want to customize the arguments, e.g. `httpx.AsyncClient(http2=True)` - if you want to reuse the connection pool of `httpx.AsyncClient` --- Or you can pass `None`(The default value), then `fastapi-proxy-lib` will create a new `httpx.AsyncClient` instance for you. From 326f8b0cf2019eb7e3d34e69da7179e523680c43 Mon Sep 17 00:00:00 2001 From: WSH032 <614337162@qq.com> Date: Fri, 5 Apr 2024 23:44:49 +0800 Subject: [PATCH 08/11] feat: enhance the robustness of WebSocket error handling and remove the HTTP version restriction --- src/fastapi_proxy_lib/core/_tool.py | 38 ----- src/fastapi_proxy_lib/core/http.py | 11 -- src/fastapi_proxy_lib/core/websocket.py | 196 +++++++++++++----------- tests/app/echo_ws_app.py | 7 +- tests/app/tool.py | 1 + tests/test_ws.py | 25 ++- 6 files changed, 133 insertions(+), 145 deletions(-) diff --git a/src/fastapi_proxy_lib/core/_tool.py b/src/fastapi_proxy_lib/core/_tool.py index 1cfd386..5604dec 100644 --- a/src/fastapi_proxy_lib/core/_tool.py +++ b/src/fastapi_proxy_lib/core/_tool.py @@ -1,13 +1,11 @@ """The utils tools for both http proxy and websocket proxy.""" import ipaddress -import logging import warnings from functools import lru_cache from textwrap import dedent from typing import ( Any, - Iterable, Mapping, Optional, Protocol, @@ -17,7 +15,6 @@ ) import httpx -from starlette import status from starlette.background import BackgroundTask as BackgroundTask_t from starlette.datastructures import ( Headers as StarletteHeaders, @@ -26,13 +23,11 @@ MutableHeaders as StarletteMutableHeaders, ) from starlette.responses import JSONResponse -from starlette.types import Scope from typing_extensions import deprecated, overload __all__ = ( "check_base_url", "return_err_msg_response", - "check_http_version", "BaseURLError", "ErrMsg", "ErrRseponseJson", @@ -129,10 +124,6 @@ class _RejectedProxyRequestError(RuntimeError): """Should be raised when reject proxy request.""" -class _UnsupportedHttpVersionError(RuntimeError): - """Unsupported http version.""" - - #################### Tools #################### @@ -337,35 +328,6 @@ def return_err_msg_response( ) -def check_http_version( - scope: Scope, supported_versions: Iterable[str] -) -> Union[JSONResponse, None]: - """Check whether the http version of scope is in supported_versions. - - Args: - scope: asgi scope dict. - supported_versions: The supported http versions. - - Returns: - If the http version of scope is not in supported_versions, return a JSONResponse with status_code=505, - else return None. - """ - # https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope - # https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope - http_version: str = scope.get("http_version", "") - # 如果明确指定了http版本(即不是""),但不在支持的版本内,则返回505 - if http_version not in supported_versions and http_version != "": - error = _UnsupportedHttpVersionError( - f"The request http version is {http_version}, but we only support {supported_versions}." - ) - # TODO: 或许可以logging记录下 scope.get("client") 的值 - return return_err_msg_response( - error, - status_code=status.HTTP_505_HTTP_VERSION_NOT_SUPPORTED, - logger=logging.info, - ) - - def default_proxy_filter(url: httpx.URL) -> Union[None, str]: """Filter by host. diff --git a/src/fastapi_proxy_lib/core/http.py b/src/fastapi_proxy_lib/core/http.py index fab3316..d9e1b09 100644 --- a/src/fastapi_proxy_lib/core/http.py +++ b/src/fastapi_proxy_lib/core/http.py @@ -31,7 +31,6 @@ _RejectedProxyRequestError, # pyright: ignore [reportPrivateUsage] # 允许使用本项目内部的私有成员 change_necessary_client_header_for_httpx, check_base_url, - check_http_version, return_err_msg_response, warn_for_none_filter, ) @@ -81,10 +80,6 @@ class _ReverseProxyServerError(RuntimeError): _NON_REQUEST_BODY_METHODS = ("GET", "HEAD", "OPTIONS", "TRACE") """The http methods that should not contain request body.""" -# https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope -SUPPORTED_HTTP_VERSIONS = ("1.0", "1.1") -"""The http versions that we supported now. It depends on `httpx`.""" - # https://www.python-httpx.org/exceptions/ _400_ERROR_NEED_TO_BE_CATCHED_IN_FORWARD_PROXY = ( httpx.InvalidURL, # 解析url时出错 @@ -227,8 +222,6 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv ) -> StarletteResponse: """Change request headers and send request to target url. - - The http version of request must be in [`SUPPORTED_HTTP_VERSIONS`][fastapi_proxy_lib.core.http.SUPPORTED_HTTP_VERSIONS]. - Args: request: the original client request. target_url: target url that request will be sent to. @@ -239,10 +232,6 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv client = self.client follow_redirects = self.follow_redirects - check_result = check_http_version(request.scope, SUPPORTED_HTTP_VERSIONS) - if check_result is not None: - return check_result - # 将请求头中的host字段改为目标url的host # 同时强制移除"keep-alive"字段和添加"keep-alive"值到"connection"字段中保持连接 require_close, proxy_header = _change_client_header( diff --git a/src/fastapi_proxy_lib/core/websocket.py b/src/fastapi_proxy_lib/core/websocket.py index f5eb113..99ded30 100644 --- a/src/fastapi_proxy_lib/core/websocket.py +++ b/src/fastapi_proxy_lib/core/websocket.py @@ -2,6 +2,7 @@ import logging import warnings +from collections import deque from contextlib import AsyncExitStack from textwrap import dedent from typing import ( @@ -15,10 +16,10 @@ ) import anyio +import anyio.abc import httpx import httpx_ws import starlette.websockets as starlette_ws -from exceptiongroup import ExceptionGroup from starlette import status as starlette_status from starlette.responses import Response as StarletteResponse from starlette.responses import StreamingResponse @@ -31,7 +32,6 @@ from ._tool import ( change_necessary_client_header_for_httpx, check_base_url, - check_http_version, ) # XXX: because these variables are private, we have to use try-except to avoid errors @@ -81,16 +81,13 @@ #################### Data Model #################### -_WsExceptionGroupType = ExceptionGroup[ - Union[starlette_ws.WebSocketDisconnect, httpx_ws.WebSocketDisconnect, Exception] +_WsDisconnectType = Union[ + starlette_ws.WebSocketDisconnect, httpx_ws.WebSocketDisconnect ] -#################### Constant #################### - +_WsDisconnectErrors = (starlette_ws.WebSocketDisconnect, httpx_ws.WebSocketDisconnect) -# https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope -SUPPORTED_WS_HTTP_VERSIONS = ("1.1",) -"""The http versions that we supported now. It depends on `httpx`.""" +#################### Constant #################### #################### Error #################### @@ -278,111 +275,94 @@ async def _starlette_ws_send_bytes_or_str( async def _wait_client_then_send_to_server( - client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession -) -> NoReturn: + client_ws: starlette_ws.WebSocket, + server_ws: httpx_ws.AsyncWebSocketSession, + ws_disconnect_deque: "deque[_WsDisconnectType]", + task_group: anyio.abc.TaskGroup, +) -> None: """Receive data from client, then send to target server. Args: client_ws: The websocket which receive data of client. server_ws: The websocket which send data to target server. + ws_disconnect_deque: A deque to store the `WebSocketDisconnect` exception. + task_group: The task group which run this task. + if a `WebSocketDisconnect` is raised, will cancel the task group. Returns: - NoReturn: Never return. Always run forever, except encounter an error, then raise it. + None: Always run forever, except encounter `WebSocketDisconnect`. Raises: error for receiving: refer to `_starlette_ws_receive_bytes_or_str`. - starlette.websockets.WebSocketDisconnect: If the WebSocket is disconnected. - - **This is normal behavior that you should catch**. error for sending: refer to `_httpx_ws_send_bytes_or_str`. """ - while True: - receive = await _starlette_ws_receive_bytes_or_str(client_ws) - await _httpx_ws_send_bytes_or_str(server_ws, receive) + try: + while True: + receive = await _starlette_ws_receive_bytes_or_str(client_ws) + await _httpx_ws_send_bytes_or_str(server_ws, receive) + except _WsDisconnectErrors as ws_disconnect: + task_group.cancel_scope.cancel() + with anyio.CancelScope(shield=True): + ws_disconnect_deque.append(ws_disconnect) async def _wait_server_then_send_to_client( - client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession -) -> NoReturn: + client_ws: starlette_ws.WebSocket, + server_ws: httpx_ws.AsyncWebSocketSession, + ws_disconnect_deque: "deque[_WsDisconnectType]", + task_group: anyio.abc.TaskGroup, +) -> None: """Receive data from target server, then send to client. Args: - client_ws: The websocket which send data to client. - server_ws: The websocket which receive data of target server. + client_ws: The websocket which receive data of client. + server_ws: The websocket which send data to target server. + ws_disconnect_deque: A deque to store the `WebSocketDisconnect` exception. + task_group: The task group which run this task. + if a `WebSocketDisconnect` is raised, will cancel the task group. Returns: - NoReturn: Never return. Always run forever, except encounter an error, then raise it. + None: Always run forever, except encounter `WebSocketDisconnect`. Raises: error for receiving: refer to `_httpx_ws_receive_bytes_or_str`. - httpx_ws.WebSocketDisconnect: If the WebSocket is disconnected. - - **This is normal behavior that you should catch** error for sending: refer to `_starlette_ws_send_bytes_or_str`. """ - while True: - receive = await _httpx_ws_receive_bytes_or_str(server_ws) - await _starlette_ws_send_bytes_or_str(client_ws, receive) + try: + while True: + receive = await _httpx_ws_receive_bytes_or_str(server_ws) + await _starlette_ws_send_bytes_or_str(client_ws, receive) + except _WsDisconnectErrors as ws_disconnect: + task_group.cancel_scope.cancel() + with anyio.CancelScope(shield=True): + ws_disconnect_deque.append(ws_disconnect) -async def _close_ws( - excgroup: _WsExceptionGroupType, - /, - *, +async def _close_ws_abnormally( client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession, + exc: BaseException, ) -> None: - """Close ws connection and send status code based on `excgroup`. + """Log the exception and close both websockets with code `1011`. Args: - excgroup: The exception group raised when running both client and server proxy tasks. - There should be at most 2 exceptions, one for client, one for server. - If contains `starlette_ws.WebSocketDisconnect`, then will close `server_ws`; - If contains `httpx_ws.WebSocketDisconnect`, then will close `client_ws`. - Else, will close both ws connections with status code `1011`. + exc: The exception propagated by task group. client_ws: client_ws server_ws: server_ws - """ - assert ( - len(excgroup.exceptions) <= 2 - ), "There should be at most 2 exceptions, one for client, one for server." - - client_ws_disc_group = ( - excgroup.subgroup( # pyright: ignore[reportUnknownMemberType] - starlette_ws.WebSocketDisconnect - ) - ) - if client_ws_disc_group: - client_disconnect = client_ws_disc_group.exceptions[0] - # XXX: `isinstance` to make pyright happy - assert isinstance(client_disconnect, starlette_ws.WebSocketDisconnect) - return await server_ws.close(client_disconnect.code, client_disconnect.reason) - - server_ws_disc_group = ( - excgroup.subgroup( # pyright: ignore[reportUnknownMemberType] - httpx_ws.WebSocketDisconnect - ) - ) - if server_ws_disc_group: - server_disconnect = server_ws_disc_group.exceptions[0] - # XXX: `isinstance` to make pyright happy - assert isinstance(server_disconnect, httpx_ws.WebSocketDisconnect) - return await client_ws.close(server_disconnect.code, server_disconnect.reason) - - # 如果上述情况都没有发生,意味着至少其中一个任务发生了异常,导致了另一个任务被取消 client_info = client_ws.client client_host, client_port = ( (client_info.host, client_info.port) if client_info is not None else (None, None) ) - # 这里不用dedent是为了更好的性能 + # we don't use `dedent` here for better performance msg = f"""\ -An error group occurred in the websocket connection for {client_host}:{client_port}. -error group: {excgroup.exceptions}\ +An error occurred in the websocket proxy connection for {client_host}:{client_port}. +errors: {exc!r}\ """ logging.warning(msg) - # Anyway, we should close both ws connections. # Why we use `1011` code, refer to: # https://developer.mozilla.org/zh-CN/docs/Web/API/CloseEvent # https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1 @@ -390,6 +370,34 @@ async def _close_ws( await server_ws.close(starlette_status.WS_1011_INTERNAL_ERROR) +async def _close_ws_normally( + client_ws: starlette_ws.WebSocket, + server_ws: httpx_ws.AsyncWebSocketSession, + ws_disconnect_deque: "deque[_WsDisconnectType]", +) -> None: + deque_len = len(ws_disconnect_deque) + if deque_len == 1: + ws_disconnect = ws_disconnect_deque[0] + if isinstance(ws_disconnect, starlette_ws.WebSocketDisconnect): + await server_ws.close(ws_disconnect.code, ws_disconnect.reason) + else: + await client_ws.close(ws_disconnect.code, ws_disconnect.reason) + elif deque_len == 2: + # If both client and server are disconnected, we do nothing. + ws_disc_type = {type(ws_disc) for ws_disc in ws_disconnect_deque} + assert ws_disc_type == { + starlette_ws.WebSocketDisconnect, + httpx_ws.WebSocketDisconnect, + } + logging.info( + f"Both client and server received disconnect. {ws_disconnect_deque!r}" + ) + else: + raise AssertionError( + f"There are too many WebSocketDisconnect in deque! {ws_disconnect_deque!r}" + ) + + #################### # #################### @@ -461,8 +469,6 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv ) -> Union[Literal[False], StarletteResponse]: """Establish websocket connection for both client and target_url, then pass messages between them. - - The http version of request must be in [`SUPPORTED_WS_HTTP_VERSIONS`][fastapi_proxy_lib.core.websocket.SUPPORTED_WS_HTTP_VERSIONS]. - Args: websocket: The client websocket requests. target_url: The url of target websocket server. @@ -492,13 +498,6 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv ) client_request_params: "QueryParamTypes" = websocket.query_params - # TODO: 是否可以不检查http版本? - check_result = check_http_version(websocket.scope, SUPPORTED_WS_HTTP_VERSIONS) - if check_result is not None: - # NOTE: return 之前最好关闭websocket - await websocket.close() - return check_result - # DEBUG: 用于调试的记录 logging.debug( "WS: client:%s ; url:%s ; params:%s ; headers:%s", @@ -535,12 +534,12 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv follow_redirects=follow_redirects, ) ) - except httpx_ws.WebSocketUpgradeError as e: + except httpx_ws.WebSocketUpgradeError as ws_upgrade_exc: # 这个错误是在 httpx.stream 获取到响应后才返回的, 也就是说至少本服务器的网络应该是正常的 # 且对于反向ws代理来说,本服务器管理者有义务保证与目标服务器的连接是正常的 # 所以这里既有可能是客户端的错误,或者是目标服务器拒绝了连接 # TODO: 也有可能是本服务器的未知错误 - proxy_res = e.response + proxy_res = ws_upgrade_exc.response # NOTE: return 之前最好关闭websocket # 不调用websocket.accept就发送关闭请求,uvicorn会自动发送403错误 @@ -570,29 +569,48 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv # headers=... ) + ws_disconnect_deque: "deque[_WsDisconnectType]" = deque(maxlen=2) + caught_tg_exc = None try: async with anyio.create_task_group() as tg: tg.start_soon( _wait_client_then_send_to_server, websocket, proxy_ws, + ws_disconnect_deque, + tg, name="client_to_server_task", ) tg.start_soon( _wait_server_then_send_to_client, websocket, proxy_ws, + ws_disconnect_deque, + tg, name="server_to_client_task", ) - # XXX: `ExceptionGroup[Any]` is illegal, so we have to ignore the type issue - except ( - ExceptionGroup - ) as excgroup: # pyright: ignore[reportUnknownVariableType] - await _close_ws( - excgroup, # pyright: ignore[reportUnknownArgumentType] - client_ws=websocket, - server_ws=proxy_ws, - ) + except BaseException as base_exc: + caught_tg_exc = base_exc + raise # NOTE: must raise again + finally: + # NOTE: DO NOT use `return` in `finally` block + with anyio.CancelScope(shield=True): + # If there are normal disconnection info, + # we try to close the connection normally. + if ws_disconnect_deque: + await _close_ws_normally( + client_ws=websocket, + server_ws=proxy_ws, + ws_disconnect_deque=ws_disconnect_deque, + ) + else: + caught_tg_exc = caught_tg_exc or RuntimeError("Unknown error") + await _close_ws_abnormally( + client_ws=websocket, + server_ws=proxy_ws, + exc=caught_tg_exc, + ) + return False @override diff --git a/tests/app/echo_ws_app.py b/tests/app/echo_ws_app.py index 132e0bd..c483f8c 100644 --- a/tests/app/echo_ws_app.py +++ b/tests/app/echo_ws_app.py @@ -87,14 +87,17 @@ async def reject_handshake(websocket: WebSocket): await websocket.close() - @app.websocket("/do_nothing") + @app.websocket("/receive_and_send_text_once_without_closing") async def do_nothing(websocket: WebSocket): - """Will do nothing except `websocket.accept()`.""" + """Will receive text once and send it back once, without closing ws.""" nonlocal test_app_dataclass test_app_dataclass.request_dict["request"] = websocket await websocket.accept() + recev = await websocket.receive_text() + await websocket.send_text(recev) + return test_app_dataclass diff --git a/tests/app/tool.py b/tests/app/tool.py index 9236e91..f3b8f85 100644 --- a/tests/app/tool.py +++ b/tests/app/tool.py @@ -258,6 +258,7 @@ async def __aenter__(self) -> Self: self._exit_stack = AsyncExitStack() await self._exit_stack.enter_async_context(self.server) + await anyio.sleep(0.5) # XXX, HACK: wait for server to start return self async def __aexit__(self, *_: Any, **__: Any) -> None: diff --git a/tests/test_ws.py b/tests/test_ws.py index 8880cc6..488a29e 100644 --- a/tests/test_ws.py +++ b/tests/test_ws.py @@ -51,6 +51,7 @@ async def run(): while True: # run forever await anyio.sleep(0.1) + # It's not proxy app server for which we test, so it's ok to not use trio backend anyio.run(run) @@ -63,22 +64,29 @@ def _subprocess_run_httpx_ws( Args: queue: The queue for subprocess to put something for flag of ws connection established. aconnect_ws_url: The websocket url for aconnect_ws. + will add "receive_and_send_text_once_without_closing" to the url. """ async def run(): _exit_stack = AsyncExitStack() _temp_client = httpx.AsyncClient(mounts=NO_PROXIES) - _ = await _exit_stack.enter_async_context( + ws = await _exit_stack.enter_async_context( aconnect_ws( client=_temp_client, - url=aconnect_ws_url, + url=aconnect_ws_url + "receive_and_send_text_once_without_closing", ) ) + # make sure ws is connected + msg = "foo" + await ws.send_text(msg) + await ws.receive_text() + # use queue to notify the connection established queue.put("done") queue.close() while True: # run forever await anyio.sleep(0.1) + # It's not proxy app server for which we test, so it's ok to not use trio backend anyio.run(run) @@ -183,7 +191,7 @@ async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: # 是因为这里已经有现成的target server,放在这里测试可以节省启动服务器时间 aconnect_ws_subprocess_queue: "Queue[str]" = Queue() - aconnect_ws_url = proxy_server_base_url + "do_nothing" + aconnect_ws_url = proxy_server_base_url aconnect_ws_subprocess = Process( target=_subprocess_run_httpx_ws, @@ -247,12 +255,19 @@ async def test_target_server_shutdown_abnormally(self) -> None: proxy_server_base_url = str(proxy_ws_server.contx_socket_url) async with aconnect_ws( - proxy_server_base_url + "do_nothing", + proxy_server_base_url + "echo_text", httpx.AsyncClient(mounts=NO_PROXIES), ) as ws0, aconnect_ws( - proxy_server_base_url + "do_nothing", + proxy_server_base_url + "echo_text", httpx.AsyncClient(mounts=NO_PROXIES), ) as ws1: + # make sure ws is connected + msg = "foo" + await ws0.send_text(msg) + assert msg == await ws0.receive_text() + await ws1.send_text(msg) + assert msg == await ws1.receive_text() + # force shutdown target server target_ws_server_subprocess.terminate() target_ws_server_subprocess.kill() From 11f9ff93f585e323191b287f0056c6e010c2cdad Mon Sep 17 00:00:00 2001 From: WSH032 <614337162@qq.com> Date: Sat, 6 Apr 2024 00:10:21 +0800 Subject: [PATCH 09/11] feat!: remove `path` parameter in `proxy` method --- src/fastapi_proxy_lib/core/http.py | 30 +++++++------------------ src/fastapi_proxy_lib/core/websocket.py | 13 ++++------- src/fastapi_proxy_lib/fastapi/router.py | 8 +++---- tests/test_docs_examples.py | 18 ++++++--------- 4 files changed, 23 insertions(+), 46 deletions(-) diff --git a/src/fastapi_proxy_lib/core/http.py b/src/fastapi_proxy_lib/core/http.py index d9e1b09..c3a7613 100644 --- a/src/fastapi_proxy_lib/core/http.py +++ b/src/fastapi_proxy_lib/core/http.py @@ -327,8 +327,8 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: # (1)! app = FastAPI(lifespan=close_proxy_event) @app.get("/{path:path}") # (2)! - async def _(request: Request, path: str = ""): - return await proxy.proxy(request=request, path=path) # (3)! + async def _(request: Request): + return await proxy.proxy(request=request) # (3)! # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` # visit the app: `http://127.0.0.1:8000/` @@ -339,10 +339,6 @@ async def _(request: Request, path: str = ""): 2. `{path:path}` is the key.
It allows the app to accept all path parameters.
visit for more info. - 3. !!! info - In fact, you only need to pass the `request: Request` argument.
- `fastapi_proxy_lib` can automatically get the `path` from `request`.
- Explicitly pointing it out here is just to remind you not to forget to specify `{path:path}`. ''' client: httpx.AsyncClient @@ -376,15 +372,12 @@ def __init__( @override async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] - self, *, request: StarletteRequest, path: Optional[str] = None + self, *, request: StarletteRequest ) -> StarletteResponse: """Send request to target server. Args: request: `starlette.requests.Request` - path: The path params of request, which means the path params of base url.
- If None, will get it from `request.path_params`.
- **Usually, you don't need to pass this argument**. Returns: The response from target server. @@ -392,9 +385,7 @@ async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] base_url = self.base_url # 只取第一个路径参数。注意,我们允许没有路径参数,这代表直接请求 - path_param: str = ( - path if path is not None else next(iter(request.path_params.values()), "") - ) + path_param: str = next(iter(request.path_params.values()), "") # 将路径参数拼接到目标url上 # e.g: "https://www.example.com/p0/" + "p1" @@ -462,8 +453,8 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: app = FastAPI(lifespan=close_proxy_event) @app.get("/{path:path}") - async def _(request: Request, path: str = ""): - return await proxy.proxy(request=request, path=path) + async def _(request: Request): + return await proxy.proxy(request=request) # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` # visit the app: `http://127.0.0.1:8000/http://www.example.com` @@ -502,15 +493,11 @@ async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] self, *, request: StarletteRequest, - path: Optional[str] = None, ) -> StarletteResponse: """Send request to target server. Args: request: `starlette.requests.Request` - path: The path params of request, which means the full url of target server.
- If None, will get it from `request.path_params`.
- **Usually, you don't need to pass this argument**. Returns: The response from target server. @@ -518,9 +505,8 @@ async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] proxy_filter = self.proxy_filter # 只取第一个路径参数 - path_param: str = ( - next(iter(request.path_params.values()), "") if path is None else path - ) + path_param: str = next(iter(request.path_params.values()), "") + # 如果没有路径参数,即在正向代理中未指定目标url,则返回400 if path_param == "": error = _BadTargetUrlError("Must provide target url.") diff --git a/src/fastapi_proxy_lib/core/websocket.py b/src/fastapi_proxy_lib/core/websocket.py index 99ded30..a785d23 100644 --- a/src/fastapi_proxy_lib/core/websocket.py +++ b/src/fastapi_proxy_lib/core/websocket.py @@ -671,8 +671,8 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: app = FastAPI(lifespan=close_proxy_event) @app.websocket("/{path:path}") - async def _(websocket: WebSocket, path: str = ""): - return await proxy.proxy(websocket=websocket, path=path) + async def _(websocket: WebSocket): + return await proxy.proxy(websocket=websocket) # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` # visit the app: `ws://127.0.0.1:8000/` @@ -737,15 +737,12 @@ def __init__( @override async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] - self, *, websocket: starlette_ws.WebSocket, path: Optional[str] = None + self, *, websocket: starlette_ws.WebSocket ) -> Union[Literal[False], StarletteResponse]: """Establish websocket connection for both client and target_url, then pass messages between them. Args: websocket: The client websocket requests. - path: The path params of websocket request, which means the path params of base url.
- If None, will get it from `websocket.path_params`.
- **Usually, you don't need to pass this argument**. Returns: If the establish websocket connection unsuccessfully: @@ -757,9 +754,7 @@ async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] base_url = self.base_url # 只取第一个路径参数。注意,我们允许没有路径参数,这代表直接请求 - path_param: str = ( - path if path is not None else next(iter(websocket.path_params.values()), "") - ) + path_param: str = next(iter(websocket.path_params.values()), "") # 将路径参数拼接到目标url上 # e.g: "https://www.example.com/p0/" + "p1" diff --git a/src/fastapi_proxy_lib/fastapi/router.py b/src/fastapi_proxy_lib/fastapi/router.py index 32ee808..c799864 100644 --- a/src/fastapi_proxy_lib/fastapi/router.py +++ b/src/fastapi_proxy_lib/fastapi/router.py @@ -63,7 +63,7 @@ def _http_register_router( @router.patch("/{path:path}", **kwargs) @router.trace("/{path:path}", **kwargs) async def http_proxy( # pyright: ignore[reportUnusedFunction] - request: Request, path: str = "" + request: Request, ) -> Response: """HTTP proxy endpoint. @@ -74,7 +74,7 @@ async def http_proxy( # pyright: ignore[reportUnusedFunction] Returns: The response from target server. """ - return await proxy.proxy(request=request, path=path) + return await proxy.proxy(request=request) def _ws_register_router( @@ -96,7 +96,7 @@ def _ws_register_router( @router.websocket("/{path:path}", **kwargs) async def ws_proxy( # pyright: ignore[reportUnusedFunction] - websocket: WebSocket, path: str = "" + websocket: WebSocket, ) -> Union[Response, Literal[False]]: """WebSocket proxy endpoint. @@ -111,7 +111,7 @@ async def ws_proxy( # pyright: ignore[reportUnusedFunction] If the establish websocket connection successfully: - Will run forever until the connection is closed. Then return False. """ - return await proxy.proxy(websocket=websocket, path=path) + return await proxy.proxy(websocket=websocket) class RouterHelper: diff --git a/tests/test_docs_examples.py b/tests/test_docs_examples.py index 12bea06..d7a9a3a 100644 --- a/tests/test_docs_examples.py +++ b/tests/test_docs_examples.py @@ -23,8 +23,8 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: app = FastAPI(lifespan=close_proxy_event) @app.get("/{path:path}") - async def _(request: Request, path: str = ""): - return await proxy.proxy(request=request, path=path) + async def _(request: Request): + return await proxy.proxy(request=request) # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` # visit the app: `http://127.0.0.1:8000/http://www.example.com` @@ -52,8 +52,8 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: # (1)! app = FastAPI(lifespan=close_proxy_event) @app.get("/{path:path}") # (2)! - async def _(request: Request, path: str = ""): - return await proxy.proxy(request=request, path=path) # (3)! + async def _(request: Request): + return await proxy.proxy(request=request) # (3)! # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` # visit the app: `http://127.0.0.1:8000/` @@ -62,11 +62,7 @@ async def _(request: Request, path: str = ""): """ 1. lifespan please refer to [starlette/lifespan](https://www.starlette.io/lifespan/) 2. `{path:path}` is the key.
It allows the app to accept all path parameters.
- visit for more info. - 3. !!! info - In fact, you only need to pass the `request: Request` argument.
- `fastapi_proxy_lib` can automatically get the `path` from `request`.
- Explicitly pointing it out here is just to remind you not to forget to specify `{path:path}`. """ + visit for more info. """ def test_reverse_ws_proxy() -> None: @@ -90,8 +86,8 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: app = FastAPI(lifespan=close_proxy_event) @app.websocket("/{path:path}") - async def _(websocket: WebSocket, path: str = ""): - return await proxy.proxy(websocket=websocket, path=path) + async def _(websocket: WebSocket): + return await proxy.proxy(websocket=websocket) # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` # visit the app: `ws://127.0.0.1:8000/` From 6ddb50ed64b2bd2cb2e51b2e1346fd484fafcdf9 Mon Sep 17 00:00:00 2001 From: WSH032 <614337162@qq.com> Date: Sun, 7 Apr 2024 08:24:30 +0800 Subject: [PATCH 10/11] feat!: support sending denial response, handshake headers, and better handling of closure BREAKING CHANGE: The return signature of `ReverseWebSocketProxy.proxy` has been changed; it no longer returns `StarletteResponse`, but instead returns `True` --- pyproject.toml | 4 +- src/fastapi_proxy_lib/core/websocket.py | 280 ++++++++++++++---------- src/fastapi_proxy_lib/fastapi/router.py | 9 +- tests/app/echo_ws_app.py | 37 +++- tests/app/tool.py | 30 +-- tests/test_ws.py | 88 +++++++- 6 files changed, 289 insertions(+), 159 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b012ebc..2d7318c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,8 +50,8 @@ dynamic = ["version"] dependencies = [ "httpx", - "httpx-ws >= 0.5.2", - "starlette", + "httpx-ws >= 0.6.0", + "starlette >= 0.37.2", "typing_extensions >=4.5.0", "anyio >= 4", "exceptiongroup", diff --git a/src/fastapi_proxy_lib/core/websocket.py b/src/fastapi_proxy_lib/core/websocket.py index a785d23..74ffb94 100644 --- a/src/fastapi_proxy_lib/core/websocket.py +++ b/src/fastapi_proxy_lib/core/websocket.py @@ -9,7 +9,6 @@ TYPE_CHECKING, Any, List, - Literal, NoReturn, Optional, Union, @@ -20,9 +19,10 @@ import httpx import httpx_ws import starlette.websockets as starlette_ws +import wsproto from starlette import status as starlette_status -from starlette.responses import Response as StarletteResponse -from starlette.responses import StreamingResponse +from starlette.background import BackgroundTask +from starlette.responses import Response from starlette.types import Scope from typing_extensions import override from wsproto.events import BytesMessage as WsprotoBytesMessage @@ -85,7 +85,10 @@ starlette_ws.WebSocketDisconnect, httpx_ws.WebSocketDisconnect ] -_WsDisconnectErrors = (starlette_ws.WebSocketDisconnect, httpx_ws.WebSocketDisconnect) +_WsDisconnectDequqType = Union[ + _WsDisconnectType, # Exception that contains closing info + Exception, # other Exception +] #################### Constant #################### @@ -257,11 +260,7 @@ async def _starlette_ws_send_bytes_or_str( data: The data to send. Raises: - When websocket has been disconnected, there may be exceptions raised, or maybe not. - # https://github.com/encode/uvicorn/discussions/2137 - For Uvicorn backend: - - `wsproto`: nothing raised. - - `websockets`: websockets.exceptions.ConnectionClosedError + When websocket has been disconnected, will raise `starlette_ws.WebSocketDisconnect(1006)`. """ # HACK: make pyright happy @@ -277,7 +276,7 @@ async def _starlette_ws_send_bytes_or_str( async def _wait_client_then_send_to_server( client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession, - ws_disconnect_deque: "deque[_WsDisconnectType]", + ws_disconnect_deque: "deque[_WsDisconnectDequqType]", task_group: anyio.abc.TaskGroup, ) -> None: """Receive data from client, then send to target server. @@ -285,12 +284,12 @@ async def _wait_client_then_send_to_server( Args: client_ws: The websocket which receive data of client. server_ws: The websocket which send data to target server. - ws_disconnect_deque: A deque to store the `WebSocketDisconnect` exception. + ws_disconnect_deque: A deque to store Exception. task_group: The task group which run this task. - if a `WebSocketDisconnect` is raised, will cancel the task group. + if catch a Exception, will cancel the task group. Returns: - None: Always run forever, except encounter `WebSocketDisconnect`. + None: Always run forever, except encounter Exception. Raises: error for receiving: refer to `_starlette_ws_receive_bytes_or_str`. @@ -300,16 +299,15 @@ async def _wait_client_then_send_to_server( while True: receive = await _starlette_ws_receive_bytes_or_str(client_ws) await _httpx_ws_send_bytes_or_str(server_ws, receive) - except _WsDisconnectErrors as ws_disconnect: + except Exception as ws_disconnect: task_group.cancel_scope.cancel() - with anyio.CancelScope(shield=True): - ws_disconnect_deque.append(ws_disconnect) + ws_disconnect_deque.append(ws_disconnect) async def _wait_server_then_send_to_client( client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession, - ws_disconnect_deque: "deque[_WsDisconnectType]", + ws_disconnect_deque: "deque[_WsDisconnectDequqType]", task_group: anyio.abc.TaskGroup, ) -> None: """Receive data from target server, then send to client. @@ -317,12 +315,12 @@ async def _wait_server_then_send_to_client( Args: client_ws: The websocket which receive data of client. server_ws: The websocket which send data to target server. - ws_disconnect_deque: A deque to store the `WebSocketDisconnect` exception. + ws_disconnect_deque: A deque to store Exception. task_group: The task group which run this task. - if a `WebSocketDisconnect` is raised, will cancel the task group. + if catch a Exception, will cancel the task group. Returns: - None: Always run forever, except encounter `WebSocketDisconnect`. + None: Always run forever, except encounter Exception. Raises: error for receiving: refer to `_httpx_ws_receive_bytes_or_str`. @@ -332,70 +330,135 @@ async def _wait_server_then_send_to_client( while True: receive = await _httpx_ws_receive_bytes_or_str(server_ws) await _starlette_ws_send_bytes_or_str(client_ws, receive) - except _WsDisconnectErrors as ws_disconnect: + except Exception as ws_disconnect: task_group.cancel_scope.cancel() - with anyio.CancelScope(shield=True): - ws_disconnect_deque.append(ws_disconnect) + ws_disconnect_deque.append(ws_disconnect) -async def _close_ws_abnormally( +async def _close_ws( # noqa: C901, PLR0912 client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession, - exc: BaseException, -) -> None: - """Log the exception and close both websockets with code `1011`. + ws_disconnect_deque: "deque[_WsDisconnectDequqType]", + caught_tg_exc: Optional[BaseException], +): + ws_disconnect_tuple = tuple(ws_disconnect_deque) + + client_disc_errs: List[starlette_ws.WebSocketDisconnect] = [] + not_client_disc_errs: List[Exception] = [] + for e in ws_disconnect_tuple: + if isinstance(e, starlette_ws.WebSocketDisconnect): + client_disc_errs.append(e) + else: + not_client_disc_errs.append(e) + + server_disc_errs: List[httpx_ws.WebSocketDisconnect] = [] + not_server_disc_errs: List[Exception] = [] + for e in ws_disconnect_tuple: + if isinstance(e, httpx_ws.WebSocketDisconnect): + server_disc_errs.append(e) + else: + not_server_disc_errs.append(e) + + is_canceled = isinstance(caught_tg_exc, anyio.get_cancelled_exc_class()) - Args: - exc: The exception propagated by task group. - client_ws: client_ws - server_ws: server_ws - """ - client_info = client_ws.client client_host, client_port = ( - (client_info.host, client_info.port) - if client_info is not None + (client_ws.client.host, client_ws.client.port) + if client_ws.client is not None else (None, None) ) - # we don't use `dedent` here for better performance - msg = f"""\ -An error occurred in the websocket proxy connection for {client_host}:{client_port}. -errors: {exc!r}\ -""" - logging.warning(msg) - # Why we use `1011` code, refer to: - # https://developer.mozilla.org/zh-CN/docs/Web/API/CloseEvent - # https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1 - await client_ws.close(starlette_status.WS_1011_INTERNAL_ERROR) - await server_ws.close(starlette_status.WS_1011_INTERNAL_ERROR) + # Implement reference: + # https://github.com/encode/starlette/blob/4e453ce91940cc7c995e6c728e3fdf341c039056/starlette/websockets.py#L64-L112 + client_ws_closed_state = { + starlette_ws.WebSocketState.DISCONNECTED, + starlette_ws.WebSocketState.RESPONSE, + } + if ( + client_ws.application_state not in client_ws_closed_state + and client_ws.client_state not in client_ws_closed_state + ): + if server_disc_errs: + server_disc = server_disc_errs[0] + await client_ws.close(server_disc.code, server_disc.reason) + elif is_canceled: + await client_ws.close(starlette_status.WS_1001_GOING_AWAY) + else: + await client_ws.close(starlette_status.WS_1011_INTERNAL_ERROR) + logging.warning( + f"[{client_host}:{client_port}] Client websocket closed abnormally during proxying. " + f"Catch tasks exceptions: {not_server_disc_errs!r} " + f"Catch task group exceptions: {caught_tg_exc!r}" + ) + + # Implement reference: + # https://github.com/frankie567/httpx-ws/blob/940c9adb3afee9dd7c8b95514bdf6444673e4820/httpx_ws/_api.py#L928-L931 + if server_ws.connection.state not in { + wsproto.connection.ConnectionState.CLOSED, + wsproto.connection.ConnectionState.LOCAL_CLOSING, + }: + if client_disc_errs: + client_disc = client_disc_errs[0] + await server_ws.close(client_disc.code, client_disc.reason) + elif is_canceled: + await server_ws.close(starlette_status.WS_1001_GOING_AWAY) + else: + # If remote server has closed normally, here we just close local ws. + # It's normal, so we don't need warning. + if ( + server_ws.connection.state + != wsproto.connection.ConnectionState.REMOTE_CLOSING + ): + logging.warning( + f"[{client_host}:{client_port}] Server websocket closed abnormally during proxying. " + f"Catch tasks exceptions: {not_client_disc_errs!r} " + f"Catch task group exceptions: {caught_tg_exc!r}" + ) + await server_ws.close(starlette_status.WS_1011_INTERNAL_ERROR) -async def _close_ws_normally( +async def _handle_ws_upgrade_error( client_ws: starlette_ws.WebSocket, - server_ws: httpx_ws.AsyncWebSocketSession, - ws_disconnect_deque: "deque[_WsDisconnectType]", + background: BackgroundTask, + ws_upgrade_exc: httpx_ws.WebSocketUpgradeError, ) -> None: - deque_len = len(ws_disconnect_deque) - if deque_len == 1: - ws_disconnect = ws_disconnect_deque[0] - if isinstance(ws_disconnect, starlette_ws.WebSocketDisconnect): - await server_ws.close(ws_disconnect.code, ws_disconnect.reason) - else: - await client_ws.close(ws_disconnect.code, ws_disconnect.reason) - elif deque_len == 2: - # If both client and server are disconnected, we do nothing. - ws_disc_type = {type(ws_disc) for ws_disc in ws_disconnect_deque} - assert ws_disc_type == { - starlette_ws.WebSocketDisconnect, - httpx_ws.WebSocketDisconnect, - } - logging.info( - f"Both client and server received disconnect. {ws_disconnect_deque!r}" + proxy_res = ws_upgrade_exc.response + # https://asgi.readthedocs.io/en/latest/extensions.html#websocket-denial-response + # https://github.com/encode/starlette/blob/4e453ce91940cc7c995e6c728e3fdf341c039056/starlette/websockets.py#L207-L214 + is_able_to_send_denial_response = "websocket.http.response" in client_ws.scope.get( + "extensions", {} + ) + + if is_able_to_send_denial_response: + # # XXX: Can not use send_denial_response with StreamingResponse + # # See: https://github.com/encode/starlette/discussions/2566 + # denial_response = StreamingResponse( + # content=proxy_res.aiter_raw(), + # status_code=proxy_res.status_code, + # headers=proxy_res.headers, + # background=background, + # ) + + # # XXX: Unable to read the content of WebSocketUpgradeError.response + # # See: https://github.com/frankie567/httpx-ws/discussions/69 + # content = await proxy_res.aread() + + denial_response = Response( + content="", + status_code=proxy_res.status_code, + headers=proxy_res.headers, + background=background, ) + await client_ws.send_denial_response(denial_response) else: - raise AssertionError( - f"There are too many WebSocketDisconnect in deque! {ws_disconnect_deque!r}" + msg = ( + "Proxy websocket handshake failed, " + "but your ASGI server does not support sending denial response.\n" + f"Denial response: {proxy_res!r}" ) + logging.warning(msg) + # we close before accept, then ASGI will send 403 to client + await client_ws.close() + await background() #################### # #################### @@ -466,7 +529,7 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv *, websocket: starlette_ws.WebSocket, target_url: httpx.URL, - ) -> Union[Literal[False], StarletteResponse]: + ) -> bool: """Establish websocket connection for both client and target_url, then pass messages between them. Args: @@ -474,11 +537,7 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv target_url: The url of target websocket server. Returns: - If the establish websocket connection unsuccessfully: - - Will call `websocket.close()` to send code `4xx` - - Then return a `StarletteResponse` from target server - If the establish websocket connection successfully: - - Will run forever until the connection is closed. Then return False. + bool: If handshake failed, return True. Else return False. """ client = self.client follow_redirects = self.follow_redirects @@ -535,41 +594,36 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv ) ) except httpx_ws.WebSocketUpgradeError as ws_upgrade_exc: - # 这个错误是在 httpx.stream 获取到响应后才返回的, 也就是说至少本服务器的网络应该是正常的 - # 且对于反向ws代理来说,本服务器管理者有义务保证与目标服务器的连接是正常的 - # 所以这里既有可能是客户端的错误,或者是目标服务器拒绝了连接 - # TODO: 也有可能是本服务器的未知错误 - proxy_res = ws_upgrade_exc.response - - # NOTE: return 之前最好关闭websocket - # 不调用websocket.accept就发送关闭请求,uvicorn会自动发送403错误 - await websocket.close() - # TODO: 连接失败的时候httpx_ws会自己关闭连接,但或许这里显式关闭会更好 - - # HACK: 这里的返回的响应其实uvicorn不会处理 - return StreamingResponse( - content=proxy_res.aiter_raw(), - status_code=proxy_res.status_code, - headers=proxy_res.headers, + await _handle_ws_upgrade_error( + client_ws=websocket, + background=BackgroundTask(stack.aclose), + ws_upgrade_exc=ws_upgrade_exc, ) + return True # NOTE: 对于反向代理服务器,我们不返回 "任何" "具体的内部" 错误信息给客户端,因为这可能涉及到服务器内部的信息泄露 # NOTE: 请使用 with 语句来 "保证关闭" AsyncWebSocketSession async with stack: - # TODO: websocket.accept 中还有一个headers参数,但是httpx_ws不支持,考虑发起PR - # https://github.com/frankie567/httpx-ws/discussions/53 - - # FIXME: 调查缺少headers参数是否会引起问题,及是否会影响透明代理的无损转发性 + proxy_ws_resp = proxy_ws.response + # TODO: Here is a typing issue of `httpx_ws`, we have to use `assert` to make pyright happy + # https://github.com/frankie567/httpx-ws/pull/54#pullrequestreview-1974062119 + assert proxy_ws_resp is not None + headers = proxy_ws_resp.headers.copy() + # ASGI not allow the headers contains `sec-websocket-protocol` field # https://asgi.readthedocs.io/en/latest/specs/www.html#accept-send-event - - # 这时候如果发生错误,退出时 stack 会自动关闭 httpx_ws 连接,所以这里不需要手动关闭 + headers.pop("sec-websocket-protocol", None) + # XXX: uvicorn websockets implementation not allow contains multiple `Date` and `Server` field, + # only wsporoto can do so. + # https://github.com/encode/uvicorn/pull/1606 + # https://github.com/python-websockets/websockets/issues/1226 + headers.pop("Date", None) + headers.pop("Server", None) await websocket.accept( - subprotocol=proxy_ws.subprotocol - # headers=... + subprotocol=proxy_ws.subprotocol, headers=headers.raw ) - ws_disconnect_deque: "deque[_WsDisconnectType]" = deque(maxlen=2) + ws_disconnect_deque: "deque[_WsDisconnectDequqType]" = deque() caught_tg_exc = None try: async with anyio.create_task_group() as tg: @@ -593,23 +647,13 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv caught_tg_exc = base_exc raise # NOTE: must raise again finally: - # NOTE: DO NOT use `return` in `finally` block with anyio.CancelScope(shield=True): - # If there are normal disconnection info, - # we try to close the connection normally. - if ws_disconnect_deque: - await _close_ws_normally( - client_ws=websocket, - server_ws=proxy_ws, - ws_disconnect_deque=ws_disconnect_deque, - ) - else: - caught_tg_exc = caught_tg_exc or RuntimeError("Unknown error") - await _close_ws_abnormally( - client_ws=websocket, - server_ws=proxy_ws, - exc=caught_tg_exc, - ) + await _close_ws( + websocket, + proxy_ws, + ws_disconnect_deque, + caught_tg_exc, + ) return False @@ -738,18 +782,14 @@ def __init__( @override async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] self, *, websocket: starlette_ws.WebSocket - ) -> Union[Literal[False], StarletteResponse]: + ) -> bool: """Establish websocket connection for both client and target_url, then pass messages between them. Args: websocket: The client websocket requests. Returns: - If the establish websocket connection unsuccessfully: - - Will call `websocket.close()` to send code `4xx` - - Then return a `StarletteResponse` from target server - If the establish websocket connection successfully: - - Will run forever until the connection is closed. Then return False. + bool: If handshake failed, return True. Else return False. """ base_url = self.base_url diff --git a/src/fastapi_proxy_lib/fastapi/router.py b/src/fastapi_proxy_lib/fastapi/router.py index c799864..346c037 100644 --- a/src/fastapi_proxy_lib/fastapi/router.py +++ b/src/fastapi_proxy_lib/fastapi/router.py @@ -10,7 +10,6 @@ AsyncContextManager, AsyncIterator, Callable, - Literal, Optional, Set, TypeVar, @@ -97,7 +96,7 @@ def _ws_register_router( @router.websocket("/{path:path}", **kwargs) async def ws_proxy( # pyright: ignore[reportUnusedFunction] websocket: WebSocket, - ) -> Union[Response, Literal[False]]: + ) -> bool: """WebSocket proxy endpoint. Args: @@ -105,11 +104,7 @@ async def ws_proxy( # pyright: ignore[reportUnusedFunction] path: The path parameters of request. Returns: - If the establish websocket connection unsuccessfully: - - Will call `websocket.close()` to send code `4xx` - - Then return a `StarletteResponse` from target server - If the establish websocket connection successfully: - - Will run forever until the connection is closed. Then return False. + bool: If handshake failed, return True. Else return False. """ return await proxy.proxy(websocket=websocket) diff --git a/tests/app/echo_ws_app.py b/tests/app/echo_ws_app.py index c483f8c..3fb4b6a 100644 --- a/tests/app/echo_ws_app.py +++ b/tests/app/echo_ws_app.py @@ -4,6 +4,7 @@ import anyio from fastapi import FastAPI, WebSocket +from starlette.responses import JSONResponse from starlette.websockets import WebSocketDisconnect from .tool import AppDataclass4Test, RequestDict @@ -53,8 +54,8 @@ async def echo_bytes(websocket: WebSocket): except WebSocketDisconnect: break - @app.websocket("/accept_foo_subprotocol") - async def accept_foo_subprotocol(websocket: WebSocket): + @app.websocket("/accept_foo_subprotocol_and_foo_bar_header") + async def accept_foo_subprotocol_and_foo_bar_header(websocket: WebSocket): """When client send subprotocols request, if subprotocols contain "foo", will accept it.""" nonlocal test_app_dataclass test_app_dataclass.request_dict["request"] = websocket @@ -65,19 +66,20 @@ async def accept_foo_subprotocol(websocket: WebSocket): else: accepted_subprotocol = None - await websocket.accept(subprotocol=accepted_subprotocol) + await websocket.accept( + subprotocol=accepted_subprotocol, headers=[(b"foo", b"bar")] + ) await websocket.close() - @app.websocket("/just_close_with_1001") - async def just_close_with_1001(websocket: WebSocket): - """Just do nothing after `accept`, then close ws with 1001 code.""" + @app.websocket("/just_close_with_1002_and_foo") + async def just_close_with_1002_and_foo(websocket: WebSocket): + """Just do nothing after `accept`, then close ws with 1001 code and 'foo'.""" nonlocal test_app_dataclass test_app_dataclass.request_dict["request"] = websocket await websocket.accept() - await anyio.sleep(0.3) - await websocket.close(1001) + await websocket.close(1002, "foo") @app.websocket("/reject_handshake") async def reject_handshake(websocket: WebSocket): @@ -87,10 +89,25 @@ async def reject_handshake(websocket: WebSocket): await websocket.close() + @app.websocket("/send_denial_response_400_foo_bar_header_and_json_body") + async def send_denial_response_400_foo_bar_header_and_json_body( + websocket: WebSocket, + ): + """Will reject ws request by just calling `websocket.close()`.""" + nonlocal test_app_dataclass + test_app_dataclass.request_dict["request"] = websocket + + denial_resp = JSONResponse({"foo": "bar"}, 400, headers={"foo": "bar"}) + await websocket.send_denial_response(denial_resp) + @app.websocket("/receive_and_send_text_once_without_closing") async def do_nothing(websocket: WebSocket): - """Will receive text once and send it back once, without closing ws.""" + """Will receive text once and send it back once, without closing ws. + + Note: user must close the ws manually, and call `websocket.state.closing.set()`. + """ nonlocal test_app_dataclass + websocket.state.closing = anyio.Event() test_app_dataclass.request_dict["request"] = websocket await websocket.accept() @@ -98,6 +115,8 @@ async def do_nothing(websocket: WebSocket): recev = await websocket.receive_text() await websocket.send_text(recev) + await websocket.state.closing.wait() + return test_app_dataclass diff --git a/tests/app/tool.py b/tests/app/tool.py index f3b8f85..05cead4 100644 --- a/tests/app/tool.py +++ b/tests/app/tool.py @@ -216,6 +216,8 @@ def contx_socket_url(self) -> httpx.URL: class AutoServer: """An AsyncContext to launch and shutdown Hypercorn or Uvicorn server automatically.""" + server_type: Literal["uvicorn", "hypercorn"] + def __init__( self, app: FastAPI, @@ -229,33 +231,33 @@ def __init__( If `host` == 0, then use random port. """ - server_type = server_type if server_type is not None else "hypercorn" - self.app = app self.host = host self.port = port - self.server_type = server_type + self._server_type: Optional[Literal["uvicorn", "hypercorn"]] = server_type + + async def __aenter__(self) -> Self: + """Launch the server.""" + if self._server_type is None: + if sniffio.current_async_library() == "asyncio": + self.server_type = "uvicorn" + else: + self.server_type = "hypercorn" + else: + self.server_type = self._server_type if self.server_type == "hypercorn": config = HyperConfig() - config.bind = f"{host}:{port}" + config.bind = f"{self.host}:{self.port}" self.config = config - self.server = _HypercornServer(app, config) + self.server = _HypercornServer(self.app, config) elif self.server_type == "uvicorn": - self.config = uvicorn.Config(app, host=host, port=port) + self.config = uvicorn.Config(self.app, host=self.host, port=self.port) self.server = _UvicornServer(self.config) else: assert_never(self.server_type) - async def __aenter__(self) -> Self: - """Launch the server.""" - if ( - self.server_type == "uvicorn" - and sniffio.current_async_library() != "asyncio" - ): - raise RuntimeError("Uvicorn server does not support trio backend.") - self._exit_stack = AsyncExitStack() await self._exit_stack.enter_async_context(self.server) await anyio.sleep(0.5) # XXX, HACK: wait for server to start diff --git a/tests/test_ws.py b/tests/test_ws.py index 488a29e..c76b912 100644 --- a/tests/test_ws.py +++ b/tests/test_ws.py @@ -2,6 +2,7 @@ from contextlib import AsyncExitStack +from dataclasses import dataclass from multiprocessing import Process, Queue from typing import Any, Dict @@ -30,6 +31,12 @@ NO_PROXIES: Dict[Any, Any] = {"all://": None} +@dataclass +class Tool4ServerTestFixture(Tool4TestFixture): # noqa: D101 + target_server: AutoServer + proxy_server: AutoServer + + def _subprocess_run_echo_ws_server(queue: "Queue[str]"): """Run echo ws app in subprocess. @@ -98,7 +105,7 @@ class TestReverseWsProxy(AbstractTestProxy): async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverride] self, auto_server_fixture: AutoServerFixture, - ) -> Tool4TestFixture: + ) -> Tool4ServerTestFixture: """目标服务器请参考`tests.app.echo_ws_app.get_app`.""" echo_ws_test_model = get_ws_test_app() echo_ws_app = echo_ws_test_model.app @@ -124,16 +131,20 @@ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverri client_for_conn_to_proxy_server = httpx.AsyncClient(mounts=NO_PROXIES) - return Tool4TestFixture( + return Tool4ServerTestFixture( client_for_conn_to_target_server=client_for_conn_to_target_server, client_for_conn_to_proxy_server=client_for_conn_to_proxy_server, get_request=echo_ws_get_request, target_server_base_url=target_server_base_url, proxy_server_base_url=proxy_server_base_url, + target_server=target_ws_server, + proxy_server=proxy_ws_server, ) @pytest.mark.anyio() - async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: + async def test_ws_proxy( # noqa: PLR0915 + self, tool_4_test_fixture: Tool4ServerTestFixture + ) -> None: """测试websocket代理.""" proxy_server_base_url = tool_4_test_fixture.proxy_server_base_url client_for_conn_to_proxy_server = ( @@ -141,6 +152,9 @@ async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: ) get_request = tool_4_test_fixture.get_request + target_server = tool_4_test_fixture.target_server + proxy_server = tool_4_test_fixture.proxy_server + ########## 测试数据的正常转发 ########## async with aconnect_ws( @@ -158,21 +172,58 @@ async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: ########## 测试子协议 ########## async with aconnect_ws( - proxy_server_base_url + "accept_foo_subprotocol", + proxy_server_base_url + "accept_foo_subprotocol_and_foo_bar_header", client_for_conn_to_proxy_server, subprotocols=["foo", "bar"], ) as ws: assert ws.subprotocol == "foo" + assert ws.response is not None + assert ws.response.headers["foo"] == "bar" - ########## 关闭代码 ########## + ########## 客户端发送关闭代码 ########## + + code = 1003 + reason = "foo" + async with aconnect_ws( + proxy_server_base_url + "receive_and_send_text_once_without_closing", + client_for_conn_to_proxy_server, + ) as ws: + await ws.send_text("foo") + await ws.receive_text() + await ws.close(code=code, reason=reason) + + target_starlette_ws = get_request() + assert isinstance(target_starlette_ws, starlette_websockets_module.WebSocket) + with pytest.raises(starlette_websockets_module.WebSocketDisconnect) as exce: + await target_starlette_ws.receive_text() + + closing_event = target_starlette_ws.state.closing + assert isinstance(closing_event, anyio.Event) + closing_event.set() + + # XXX, HACK, TODO: + # hypercorn can't receive correctly close code, it always receive 1006 + # https://github.com/pgjones/hypercorn/issues/127 + # so we only test close code for uvicorn + if ( + target_server.server_type == "uvicorn" + and proxy_server.server_type == "uvicorn" + ): + assert exce.value.code == code + # XXX, HACK, TODO: + # reaseon are wrong, httpx-ws can't send close reason correctly + # assert exce.value.reason == reason + + ########## 服务端发送关闭代码 ########## async with aconnect_ws( - proxy_server_base_url + "just_close_with_1001", + proxy_server_base_url + "just_close_with_1002_and_foo", client_for_conn_to_proxy_server, ) as ws: with pytest.raises(httpx_ws.WebSocketDisconnect) as exce: await ws.receive_text() - assert exce.value.code == 1001 + assert exce.value.code == 1002 + assert exce.value.reason == "foo" ########## 协议升级失败或者连接失败 ########## @@ -185,6 +236,22 @@ async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: # Starlette 在未调用`websocket.accept()`之前调用了`websocket.close()`,会发生403 assert exce.value.response.status_code == 403 + ########## test denial response ########## + + with pytest.raises(httpx_ws.WebSocketUpgradeError) as exce: + async with aconnect_ws( + proxy_server_base_url + + "send_denial_response_400_foo_bar_header_and_json_body", + client_for_conn_to_proxy_server, + ) as ws: + pass + # Starlette 在未调用`websocket.accept()`之前调用了`websocket.close()`,会发生403 + assert exce.value.response.status_code == 400 + assert exce.value.response.headers["foo"] == "bar" + # XXX, HACK, TODO: Unable to read the content of WebSocketUpgradeError.response + # See: https://github.com/frankie567/httpx-ws/discussions/69 + # assert exce.value.response.json() == {"foo": "bar"} + ########## 客户端突然关闭时,服务器应该收到1011 ########## # NOTE: 这个测试不放在 `test_target_server_shutdown_abnormally` 来做 @@ -215,10 +282,15 @@ async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: with pytest.raises(starlette_websockets_module.WebSocketDisconnect) as exce: await target_starlette_ws.receive_text() # receive_bytes() 也可以 + closing_event = target_starlette_ws.state.closing + assert isinstance(closing_event, anyio.Event) + closing_event.set() + # assert exce.value.code == 1011 # HACK, FIXME: 无法测试错误代码,似乎无法正常传递,且不同后端也不同 # FAILED test_ws_proxy[websockets] - assert 1005 == 1011 # FAILED test_ws_proxy[wsproto] - assert == 1011 + # NOTE: the close code for abnormal close is undefined behavior, so we won't test this # FIXME: 调查为什么收到关闭代码需要40s @pytest.mark.timeout(60) @@ -289,3 +361,5 @@ async def test_target_server_shutdown_abnormally(self) -> None: # 只要第二个客户端不是在之前40s基础上又重复40s,就暂时没问题, # 因为这模拟了多个客户端进行连接的情况。 assert (seconde_ws_recv_end - seconde_ws_recv_start) < 2 + + # NOTE: the close code for abnormal close is undefined behavior, so we won't test this From 3d2b0ea27a0b083886c1d03e37645bb5a7050dfa Mon Sep 17 00:00:00 2001 From: WSH032 <614337162@qq.com> Date: Mon, 8 Apr 2024 22:37:40 +0800 Subject: [PATCH 11/11] refactor: remove improperly exposed private module variables --- src/fastapi_proxy_lib/core/websocket.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/fastapi_proxy_lib/core/websocket.py b/src/fastapi_proxy_lib/core/websocket.py index 74ffb94..223c381 100644 --- a/src/fastapi_proxy_lib/core/websocket.py +++ b/src/fastapi_proxy_lib/core/websocket.py @@ -55,14 +55,13 @@ ) DEFAULT_QUEUE_SIZE = 512 # pyright: ignore[reportConstantRedefinition] - msg = dedent( - """\ - Can not import the default httpx_ws arguments, please open an issue on: - https://github.com/WSH032/fastapi-proxy-lib\ - """ - ) warnings.warn( - msg, + dedent( + """\ + Can not import the default httpx_ws arguments, please open an issue on: + https://github.com/WSH032/fastapi-proxy-lib\ + """ + ), RuntimeWarning, stacklevel=1, )