diff --git a/CHANGELOG.md b/CHANGELOG.md index 336f974..0fb92f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- [#41](https://github.com/WSH032/fastapi-proxy-lib/pull/41) - feat: add `callback` api to `WebSocketProxy`. Thanks [@WSH032](https://github.com/WSH032) and [@IvesAwadi](https://github.com/IvesAwadi)! + ### Changed - [#30](https://github.com/WSH032/fastapi-proxy-lib/pull/30) - fix(internal): use `websocket` in favor of `websocket_route`. Thanks [@WSH032](https://github.com/WSH032)! diff --git a/README.md b/README.md index 56d2f71..5c5da6c 100644 --- a/README.md +++ b/README.md @@ -32,8 +32,9 @@ Source Code: - [x] Support both **reverse** proxy and **forward** proxy. - [x] **Transparently** and **losslessly** handle all proxy requests, Including **HTTP headers**, **cookies**, **query parameters**, **body**, etc. +- [X] WebSocket proxy **callback**. - [x] Asynchronous streaming transfer, support **file proxy**. -- [x] `fastapi-proxy-lib` value [privacy security](https://wsh032.github.io/fastapi-proxy-lib/Usage/Security/). +- [x] `fastapi-proxy-lib` value [**privacy security**](https://wsh032.github.io/fastapi-proxy-lib/Usage/Security/). ### other features diff --git a/docs/Usage/Advanced.md b/docs/Usage/Advanced.md index b2e522d..0951399 100644 --- a/docs/Usage/Advanced.md +++ b/docs/Usage/Advanced.md @@ -4,7 +4,8 @@ For the following scenarios, you might prefer [fastapi_proxy_lib.core][]: - When you need to use proxies with **only** `Starlette` dependencies (without `FastAPI`). - When you need more fine-grained control over parameters and lifespan event. -- When you need to further process the input and output before and after the proxy (similar to middleware). +- When you need to further process the input and output before and after the http proxy (similar to `middleware`). +- When you need `callback` to modify the websocket proxy messages. We will demonstrate with `FastAPI`, but you can completely switch to the `Starlette` approach, @@ -19,13 +20,13 @@ Also (without annotations): - [`ForwardHttpProxy#examples`][fastapi_proxy_lib.core.http.ForwardHttpProxy--examples] - [`ReverseWebSocketProxy#examples`][fastapi_proxy_lib.core.websocket.ReverseWebSocketProxy--examples] -## Modify request +## Modify HTTP request In some cases, you may want to make final modifications before sending a request, such as performing behind-the-scenes authentication by modifying the headers of request. `httpx` provides comprehensive authentication support, and `fastapi-proxy-lib` offers first-class support for `httpx`. -See +See You can refer following example to implement a simple authentication: @@ -35,7 +36,7 @@ from fastapi_proxy_lib.fastapi.app import reverse_http_app class MyCustomAuth(httpx.Auth): - # ref: https://www.python-httpx.org/advanced/#customizing-authentication + # ref: https://www.python-httpx.org/advanced/authentication/ def __init__(self, token: str): self.token = token @@ -55,7 +56,7 @@ app = reverse_http_app( visit `/headers` to see the result which contains `"X-Authentication": "bearer_token"` header. -## Modify response +## Modify HTTP response In some cases, you may want to make final modifications before return the response to the client, such as transcoding video response streams. @@ -118,3 +119,19 @@ async def _(request: Request, path: str = ""): ``` visit `/`, you will notice that the response body is printed to the console. + +## Modify WebSocket message + +In some cases, you might want to modify the content of the messages that the WebSocket proxy receives and sends to the client and target server. + +In version `0.2.0` of `fastapi-proxy-lib`, we introduced a [`callback API`][fastapi_proxy_lib.core.websocket.ReverseWebSocketProxy.proxy] for `WebSocketProxy` to allow you to do this. + +See example: [ReverseWebSocketProxy#with-callback][fastapi_proxy_lib.core.websocket.ReverseWebSocketProxy--with-callback] + +Also: + +- RFC: [#40](https://github.com/WSH032/fastapi-proxy-lib/issues/40) +- PR: [#41](https://github.com/WSH032/fastapi-proxy-lib/pull/41) + +!!!example + The current implementation still has some defects. Read the [callback-implementation][fastapi_proxy_lib.core.websocket.BaseWebSocketProxy.send_request_to_target--callback-implementation] section, or you might accidentally shoot yourself in the foot. diff --git a/docs/Usage/FastAPI-Helper.md b/docs/Usage/FastAPI-Helper.md index 45f1ebe..9a25600 100644 --- a/docs/Usage/FastAPI-Helper.md +++ b/docs/Usage/FastAPI-Helper.md @@ -10,7 +10,7 @@ There are two helper modules to get FastAPI `app`/`router` for proxy convenientl ## app -use `fastapi_proxy_lib.fastapi.app` is very convenient and out of the box, there are three helper functions: +`fastapi_proxy_lib.fastapi.app` is very convenient and out of the box, there are three helper functions: - [forward_http_app][fastapi_proxy_lib.fastapi.app.forward_http_app] - [reverse_http_app][fastapi_proxy_lib.fastapi.app.reverse_http_app] @@ -46,3 +46,9 @@ For the following scenarios, you might prefer [fastapi_proxy_lib.fastapi.router] - When you need to [mount the proxy on a route of larger app](https://fastapi.tiangolo.com/tutorial/bigger-applications/). **^^[Please refer to the documentation of `RouterHelper` for more information :material-file-document: ][fastapi_proxy_lib.fastapi.router.RouterHelper--examples]^^**. + +--- + +## More + +**The `Helper Module` might not meet your further customization needs. Please refer to the [Advanced](Advanced.md) section, which is the core of `fastapi-proxy-lib`, for more personalization options.** diff --git a/mkdocs.yml b/mkdocs.yml index d4c2e3a..26884ad 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -108,6 +108,8 @@ plugins: import: - https://frankie567.github.io/httpx-ws/objects.inv - https://fastapi.tiangolo.com/objects.inv + - https://anyio.readthedocs.io/en/stable/objects.inv + - https://docs.python.org/3/objects.inv options: docstring_style: google paths: [src] diff --git a/pyproject.toml b/pyproject.toml index f3c9bbf..f545340 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,8 @@ dependencies = [ "httpx", "httpx-ws >= 0.4.2", "starlette", - "typing_extensions >=4.5.0", + "typing_extensions >= 4.12", + "anyio >= 4", ] [project.optional-dependencies] @@ -91,13 +92,12 @@ dependencies = [ # NOTE: 👆 # lint-check - "pyright == 1.1.356", # pyright must be installed in the runtime environment + "pyright == 1.1.372", # pyright must be installed in the runtime environment # test "pytest == 7.*", "pytest-cov == 4.*", "uvicorn[standard] < 1.0.0", # TODO: Once it releases version 1.0.0, we will remove this restriction. "httpx[http2]", # we don't set version here, instead set it in `[project].dependencies`. - "anyio", # we don't set version here, because fastapi has a dependency on it "asgi-lifespan==2.*", "pytest-timeout==2.*", ] diff --git a/src/fastapi_proxy_lib/core/http.py b/src/fastapi_proxy_lib/core/http.py index fab3316..e891502 100644 --- a/src/fastapi_proxy_lib/core/http.py +++ b/src/fastapi_proxy_lib/core/http.py @@ -222,7 +222,7 @@ class BaseHttpProxy(BaseProxyModel): """ @override - async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOverride] + async def send_request_to_target( self, *, request: StarletteRequest, target_url: httpx.URL ) -> StarletteResponse: """Change request headers and send request to target url. @@ -318,6 +318,8 @@ class ReverseHttpProxy(BaseHttpProxy): # # Examples + ## Basic usage + ```python from contextlib import asynccontextmanager from typing import AsyncIterator @@ -341,7 +343,7 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: # (1)! async def _(request: Request, path: str = ""): return await proxy.proxy(request=request, path=path) # (3)! - # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` + # Then run shell: `uvicorn :app --host 127.0.0.1 --port 8000` # visit the app: `http://127.0.0.1:8000/` # you will get the response from `http://www.example.com/` ``` @@ -452,6 +454,8 @@ class ForwardHttpProxy(BaseHttpProxy): # # Examples + ## Basic usage + ```python from contextlib import asynccontextmanager from typing import AsyncIterator @@ -476,7 +480,7 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: async def _(request: Request, path: str = ""): return await proxy.proxy(request=request, path=path) - # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` + # Then run shell: `uvicorn :app --host 127.0.0.1 --port 8000` # visit the app: `http://127.0.0.1:8000/http://www.example.com` # you will get the response from `http://www.example.com` ``` diff --git a/src/fastapi_proxy_lib/core/websocket.py b/src/fastapi_proxy_lib/core/websocket.py index 36dacd7..5c343a2 100644 --- a/src/fastapi_proxy_lib/core/websocket.py +++ b/src/fastapi_proxy_lib/core/websocket.py @@ -2,21 +2,36 @@ import asyncio import logging -from contextlib import AsyncExitStack +import warnings +from contextlib import ( + AbstractContextManager, + AsyncExitStack, + ExitStack, + asynccontextmanager, + nullcontext, +) from typing import ( TYPE_CHECKING, Any, + AsyncIterator, + Callable, + ContextManager, + Coroutine, + Generic, List, Literal, NamedTuple, NoReturn, Optional, + Tuple, Union, ) import httpx import httpx_ws import starlette.websockets as starlette_ws +from anyio import create_memory_object_stream +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_ws._api import ( # HACK: 注意,这个是私有模块 DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, @@ -29,6 +44,7 @@ from starlette.responses import StreamingResponse from starlette.types import Scope from typing_extensions import TypeAlias, override +from typing_extensions import TypeVar as TypeVarExt from wsproto.events import BytesMessage as WsprotoBytesMessage from wsproto.events import TextMessage as WsprotoTextMessage @@ -39,10 +55,7 @@ check_http_version, ) -__all__ = ( - "BaseWebSocketProxy", - "ReverseWebSocketProxy", -) +__all__ = ("BaseWebSocketProxy", "ReverseWebSocketProxy", "CallbackPipeContextType") if TYPE_CHECKING: # 这些是私有模块,无法确定以后版本是否会改变,为了保证运行时不会出错,我们使用TYPE_CHECKING @@ -56,6 +69,39 @@ _ServerToClientTaskType: TypeAlias = "asyncio.Task[httpx_ws.WebSocketDisconnect]" +_DefaultWsMsgType = Union[str, bytes] +_WsMsgType = Union[str, bytes, _DefaultWsMsgType] +"""Websocket message type.""" +_WsMsgTypeVar = TypeVarExt("_WsMsgTypeVar", bound=_WsMsgType, default=_DefaultWsMsgType) +"""Generic websocket message type.""" +_CallbackPipeType = Tuple[ + MemoryObjectSendStream[_WsMsgTypeVar], MemoryObjectReceiveStream[_WsMsgTypeVar] +] +"""Send end and receive end of a callback pipe.""" +CallbackPipeContextType = ContextManager[_CallbackPipeType[_WsMsgTypeVar]] +"""A context manager that will automatically close the pipe when exit. + +The `__enter__` method will return one end of each pair of +[`memory-object-streams`](https://anyio.readthedocs.io/en/stable/streams.html#memory-object-streams) pipelines. + +See example: [ReverseWebSocketProxy#with-callback][fastapi_proxy_lib.core.websocket.ReverseWebSocketProxy--with-callback] + +Warning: + This is a unstable public type hint, you shouldn't rely on it. + You should create your own type hint instead. + +Note: + You must ensure that **exit** the context manager, + or maybe you will get a **deadlock**. + + See: [`callback-implementation`][fastapi_proxy_lib.core.websocket.BaseWebSocketProxy.send_request_to_target--callback-implementation] +""" +_CallbackType = Callable[ + [CallbackPipeContextType[_WsMsgTypeVar]], Coroutine[None, None, None] +] +"""The websocket callback provided by user.""" + + class _ClientServerProxyTask(NamedTuple): """The task group for passing websocket message between client and target server.""" @@ -70,6 +116,8 @@ class _ClientServerProxyTask(NamedTuple): SUPPORTED_WS_HTTP_VERSIONS = ("1.1",) """The http versions that we supported now. It depends on `httpx`.""" +_CALLBACK_BUFFER_SIZE = 0 +"""The buffer size of the callback pipe.""" #################### Error #################### @@ -97,6 +145,84 @@ def _get_client_request_subprotocols(ws_scope: Scope) -> Union[List[str], None]: return subprotocols +class _PipeContextBuilder( + # before py3.9, `AbstractContextManager` is not generic + AbstractContextManager, # pyright: ignore[reportMissingTypeArgument] + Generic[_WsMsgTypeVar], +): + """Auto close the pipe when exit the context.""" + + def __init__(self, raw_pipe: _CallbackPipeType[_WsMsgTypeVar]) -> None: + self._raw_pipe = raw_pipe + self._exit_stack = ExitStack() + self._exited = False + + def __enter__(self) -> _CallbackPipeType[_WsMsgTypeVar]: + sender, receiver = self._raw_pipe + self._exit_stack.enter_context(sender) + self._exit_stack.enter_context(receiver) + return self._raw_pipe + + def __exit__(self, *_: Any) -> None: + self._exit_stack.__exit__(*_) + self._exited = True + + def __del__(self): + if not self._exited: + warnings.warn( + "You never exit the pipe context, it may cause a deadlock.", + RuntimeWarning, + stacklevel=1, + ) + + +@asynccontextmanager +async def _wait_task_and_ignore_exce( + callback_task: "asyncio.Task[Any]", +) -> AsyncIterator[None]: + """Wait for the task when exiting, but ignore the exception.""" + yield + try: + await callback_task + except Exception: + pass + + +async def _enable_callback( + callback: _CallbackType[_WsMsgTypeVar], task_name: str, exit_stack: AsyncExitStack +) -> CallbackPipeContextType[_WsMsgTypeVar]: + """Create a task to run the callback. + + - The callback task will be awaited when exit the `exit_stack`, + but the exception of callback task will be ignored. + - When the callback done(normal or exception), + the pipe used in the callback will be closed; + this is for preventing callback forgetting to close the pipe and causing deadlock. + ```py + async def callback(ctx: CallbackPipeContextType[str]) -> None: + pass + ``` + NOTE: This is just a mitigation measure and may be removed in the future, + so it needs to be documented in the public documentation. + """ + proxy_sender, cb_receiver = create_memory_object_stream[_WsMsgTypeVar]( + _CALLBACK_BUFFER_SIZE + ) + cb_sender, proxy_receiver = create_memory_object_stream[_WsMsgTypeVar]( + _CALLBACK_BUFFER_SIZE + ) + cb_pipe_ctx = _PipeContextBuilder((cb_sender, cb_receiver)) + proxy_pipe_ctx = _PipeContextBuilder((proxy_sender, proxy_receiver)) + + cb_task = asyncio.create_task(callback(cb_pipe_ctx), name=task_name) + # use `done_callback` to close the pipe when the callback task is done, + # NOTE: we close `sender` and `receiver` directly, instead of using `cb_pipe_ctx.__exit__`, + # so that `_PipeContextBuilder` issue a warning when the callback forgets to close the pipe. + cb_task.add_done_callback(lambda _: (cb_sender.close(), cb_receiver.close())) + await exit_stack.enter_async_context(_wait_task_and_ignore_exce(cb_task)) + return proxy_pipe_ctx + + # TODO: 等待starlette官方的支持 # 为什么使用这个函数而不是直接使用starlette_WebSocket.receive_text() # 请看: https://github.com/encode/starlette/discussions/2310 @@ -261,13 +387,19 @@ async def _starlette_ws_send_bytes_or_str( async def _wait_client_then_send_to_server( - *, client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession + *, + client_ws: starlette_ws.WebSocket, + server_ws: httpx_ws.AsyncWebSocketSession, + pipe_context: Optional[CallbackPipeContextType[_WsMsgTypeVar]] = None, ) -> starlette_ws.WebSocketDisconnect: """Receive data from client, then send to target server. Args: client_ws: The websocket which receive data of client. server_ws: The websocket which send data to target server. + pipe_context: The callback pipe for processing data. + will send the received data(from client) to the sender, + and receive the data from the receiver(then send to the server). Returns: If the client_ws sends a shutdown message normally, will return starlette_ws.WebSocketDisconnect. @@ -275,24 +407,39 @@ async def _wait_client_then_send_to_server( Raises: error for receiving: refer to `_starlette_ws_receive_bytes_or_str` error for sending: refer to `_httpx_ws_send_bytes_or_str` + error for callback: refer to `MemoryObjectReceiveStream.receive` and `MemoryObjectSendStream.send` """ - while True: - try: - receive = await _starlette_ws_receive_bytes_or_str(client_ws) - except starlette_ws.WebSocketDisconnect as e: - return e - else: + with pipe_context or nullcontext() as pipe: + while True: + try: + receive = await _starlette_ws_receive_bytes_or_str(client_ws) + except starlette_ws.WebSocketDisconnect as e: + return e + + # TODO: do not use `if` statement in loop + if pipe is not None: + sender, receiver = pipe + # XXX, HACK, TODO: We can't identify the msg type from websocket, + # so we have to igonre the type check here. + await sender.send(receive) # pyright: ignore [reportArgumentType] + receive = await receiver.receive() await _httpx_ws_send_bytes_or_str(server_ws, receive) async def _wait_server_then_send_to_client( - *, client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession + *, + client_ws: starlette_ws.WebSocket, + server_ws: httpx_ws.AsyncWebSocketSession, + pipe_context: Optional[CallbackPipeContextType[_WsMsgTypeVar]] = None, ) -> httpx_ws.WebSocketDisconnect: """Receive data from target server, then send to client. Args: client_ws: The websocket which send data to client. server_ws: The websocket which receive data of target server. + pipe_context: The callback pipe for processing data. + will send the received data(from server) to the sender, + and receive the data from the receiver(then send to the client). Returns: If the server_ws sends a shutdown message normally, will return httpx_ws.WebSocketDisconnect. @@ -300,13 +447,22 @@ async def _wait_server_then_send_to_client( Raises: error for receiving: refer to `_httpx_ws_receive_bytes_or_str` error for sending: refer to `_starlette_ws_send_bytes_or_str` + error for callback: refer to `MemoryObjectReceiveStream.receive` and `MemoryObjectSendStream.send` """ - while True: - try: - receive = await _httpx_ws_receive_bytes_or_str(server_ws) - except httpx_ws.WebSocketDisconnect as e: - return e - else: + with pipe_context or nullcontext() as pipe: + while True: + try: + receive = await _httpx_ws_receive_bytes_or_str(server_ws) + except httpx_ws.WebSocketDisconnect as e: + return e + + # TODO: do not use `if` statement in loop + if pipe is not None: + sender, receiver = pipe + # XXX, HACK, TODO: We can't identify the msg type from websocket, + # so we have to igonre the type check here. + await sender.send(receive) # pyright: ignore [reportArgumentType] + receive = await receiver.receive() await _starlette_ws_send_bytes_or_str(client_ws, receive) @@ -396,6 +552,14 @@ async def _close_ws( #################### # #################### +_WsMsgTypeVar_CTS = TypeVarExt( + "_WsMsgTypeVar_CTS", bound=_WsMsgType, default=_DefaultWsMsgType +) +_WsMsgTypeVar_STC = TypeVarExt( + "_WsMsgTypeVar_STC", bound=_WsMsgType, default=_DefaultWsMsgType +) + + class BaseWebSocketProxy(BaseProxyModel): """Websocket proxy base class. @@ -408,7 +572,7 @@ class BaseWebSocketProxy(BaseProxyModel): keepalive_ping_timeout_seconds: refer to [httpx_ws.aconnect_ws][] Tip: - [`httpx_ws.aconnect_ws`](https://frankie567.github.io/httpx-ws/reference/httpx_ws/#httpx_ws.aconnect_ws) + [httpx_ws.aconnect_ws][] """ client: httpx.AsyncClient @@ -447,7 +611,7 @@ def __init__( keepalive_ping_timeout_seconds: refer to [httpx_ws.aconnect_ws][] Tip: - [`httpx_ws.aconnect_ws`](https://frankie567.github.io/httpx-ws/reference/httpx_ws/#httpx_ws.aconnect_ws) + [httpx_ws.aconnect_ws][] """ self.max_message_size_bytes = max_message_size_bytes self.queue_size = queue_size @@ -456,11 +620,13 @@ def __init__( super().__init__(client, follow_redirects=follow_redirects) @override - async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOverride] + async def send_request_to_target( self, *, websocket: starlette_ws.WebSocket, target_url: httpx.URL, + client_to_server_callback: Optional[_CallbackType[_WsMsgTypeVar_CTS]] = None, + server_to_client_callback: Optional[_CallbackType[_WsMsgTypeVar_STC]] = None, ) -> Union[Literal[False], StarletteResponse]: """Establish websocket connection for both client and target_url, then pass messages between them. @@ -469,6 +635,76 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv Args: websocket: The client websocket requests. target_url: The url of target websocket server. + client_to_server_callback: The callback function for processing data from client to server. + The usage example is in subclass [ReverseWebSocketProxy#with-callback][fastapi_proxy_lib.core.websocket.ReverseWebSocketProxy--with-callback]. + server_to_client_callback: The callback function for processing data from server to client. + The usage example is in subclass [ReverseWebSocketProxy#with-callback][fastapi_proxy_lib.core.websocket.ReverseWebSocketProxy--with-callback]. + + ## Callback implementation + + Note: The `callback` implementation details: + - If `callback` is not None, will create a new task to run the callback function. + When the whole proxy task finishes, the callback task will be awaited, but the exception will be ignored. + - `callback` must ensure that it closes the pipe(exit the pipe context) when it finishes or encounters an exception, + or you will get a **deadlock**. + - A common mistake is the `callback` encountering an exception or returning before entering the context. + ```py + async def callback(ctx: CallbackPipeContextType[str]) -> None: + # mistake: not entering the context + return + + async def callback(ctx: CallbackPipeContextType[str]) -> None: + # mistake: encountering an exception before entering the context + 1 / 0 + with ctx as (sender, receiver): + pass + ``` + - If the callback-side pipe is closed but proxy task is still running, + `proxy` will treat it as a exception and close websocket connection. + - If `proxy` encounters an exception or receives a disconnection request, the proxy-side pipe will be closed, + then the callback-side pipe will receive an exception + (refer to [send][anyio.streams.memory.MemoryObjectSendStream.send], + also [receive][anyio.streams.memory.MemoryObjectReceiveStream.receive]). + - **The buffer size of the pipe is currently 0 (this may change in the future)**, + which means that if the `callback` does not call the receiver, the `WebSocket` will block. + Therefore, the `callback` should not take too long to process a single message. + If you expect to be unable to call the receiver for an extended period, + you need to create your own buffer to store messages. + + See also: + + - [RFC#40](https://github.com/WSH032/fastapi-proxy-lib/issues/40) + - [memory-object-streams](https://anyio.readthedocs.io/en/stable/streams.html#memory-object-streams) + + Bug: Dead lock + The current implementation only supports a strict `one-receive-one-send` mode within a single loop. + If this pattern is violated, such as `multiple receives and one send`, `one receive and multiple sends`, + or `sending before receiving` within a single loop, it will result in a deadlock. + + See Issue Tracker: [#42](https://github.com/WSH032/fastapi-proxy-lib/issues/42) + + ```py + async def callback(ctx: CallbackPipeContextType[str]) -> None: + with ctx as (sender, receiver): + # multiple receives and one send, dead lock! + await receiver.receive() + await receiver.receive() + await sender.send("foo") + + async def callback(ctx: CallbackPipeContextType[str]) -> None: + with ctx as (sender, receiver): + # one receive and multiple sends, dead lock! + async for message in receiver: + await sender.send("foo") + await sender.send("bar") + + async def callback(ctx: CallbackPipeContextType[str]) -> None: + with ctx as (sender, receiver): + # sending before receiving, dead lock! + await sender.send("foo") + async for message in receiver: + await sender.send(message) + ``` Returns: If the establish websocket connection unsuccessfully: @@ -559,6 +795,9 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv # NOTE: 对于反向代理服务器,我们不返回 "任何" "具体的内部" 错误信息给客户端,因为这可能涉及到服务器内部的信息泄露 + # NOTE: Do not exit the `stack` before return `StreamingResponse` above. + # Because once the `stack` close, the `httpx_ws` connection will be closed, + # then the streaming response will encounter an error. # NOTE: 请使用 with 语句来 "保证关闭" AsyncWebSocketSession async with stack: # TODO: websocket.accept 中还有一个headers参数,但是httpx_ws不支持,考虑发起PR @@ -573,17 +812,34 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv # headers=... ) + cts_pipe_ctx = ( + await _enable_callback( + client_to_server_callback, "client_to_server_callback", stack + ) + if client_to_server_callback is not None + else None + ) client_to_server_task = asyncio.create_task( _wait_client_then_send_to_server( client_ws=websocket, server_ws=proxy_ws, + pipe_context=cts_pipe_ctx, ), name="client_to_server_task", ) + + stc_pipe_ctx = ( + await _enable_callback( + server_to_client_callback, "server_to_client_callback", stack + ) + if server_to_client_callback is not None + else None + ) server_to_client_task = asyncio.create_task( _wait_server_then_send_to_client( client_ws=websocket, server_ws=proxy_ws, + pipe_context=stc_pipe_ctx, ), name="server_to_client_task", ) @@ -610,6 +866,8 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv # 因为第二种情况的存在,所以需要用 wait_for 强制让其退出 # 但考虑到第一种情况,先等它 1s ,看看能否正常退出 try: + # Here we just cancel the task that is not finished, + # but we don't handle websocket closing here. _, pending = await asyncio.wait( task_group, return_when=asyncio.FIRST_COMPLETED, @@ -666,7 +924,7 @@ class ReverseWebSocketProxy(BaseWebSocketProxy): keepalive_ping_timeout_seconds: refer to [httpx_ws.aconnect_ws][] Tip: - [`httpx_ws.aconnect_ws`](https://frankie567.github.io/httpx-ws/reference/httpx_ws/#httpx_ws.aconnect_ws) + [httpx_ws.aconnect_ws][] Bug: There is a issue for handshake response: This WebSocket proxy can correctly forward request headers. @@ -681,6 +939,8 @@ class ReverseWebSocketProxy(BaseWebSocketProxy): # # Examples + ## Basic usage + ```python from contextlib import asynccontextmanager from typing import AsyncIterator @@ -704,10 +964,58 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: async def _(websocket: WebSocket, path: str = ""): return await proxy.proxy(websocket=websocket, path=path) - # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` + # Then run shell: `uvicorn :app --host 127.0.0.1 --port 8000` # visit the app: `ws://127.0.0.1:8000/` # you can establish websocket connection with `ws://echo.websocket.events` ``` + + ## With callback + + Tip: + See also: [`callback-implementation`][fastapi_proxy_lib.core.websocket.BaseWebSocketProxy.send_request_to_target--callback-implementation] + + ```python + # NOTE: `CallbackPipeContextType` is a unstable public type hint, + # you shouldn't rely on it. + # You should create your own type hint instead. + from fastapi_proxy_lib.core.websocket import CallbackPipeContextType + + # NOTE: Providing a specific type annotation for `CallbackPipeContextType` + # does not offer any runtime guarantees. + # This means that even if you annotate it as `[str]`, + # you may still receive `bytes` types, + # unless you are certain that the other end will only send `str`. + # The default generic is `[str | bytes]`. + + async def client_to_server_callback(pipe_context: CallbackPipeContextType[str]) -> None: + with pipe_context as (sender, receiver): + async for message in receiver: + print(f"Received from client: {message}") + # here we modify the message with `CTS:` prefix + await sender.send(f"CTS:{message}") + # If `proxy` receives a disconnection request and websocket closed correctly, + # this message will be printed. Or not, if exception occurs. + print("client_to_server_callback end") + + # we give a `bytes` type hint here, but it has no runtime guarantee, + # just make type-checker happy + async def server_to_client_callback(pipe_context: CallbackPipeContextType[bytes]) -> None: + with pipe_context as (sender, receiver): + async for message in receiver: + print(f"Received from server: {message}") + await sender.send(f"STC:{message}".encode()) # `bytes` here + print("server_to_client_callback end") + + @app.websocket("/{path:path}") + async def _(websocket: WebSocket, path: str = ""): + return await proxy.proxy( + websocket=websocket, + path=path, + # register callback + client_to_server_callback=client_to_server_callback, + server_to_client_callback=server_to_client_callback, + ) + ``` ''' client: httpx.AsyncClient @@ -753,7 +1061,7 @@ def __init__( keepalive_ping_timeout_seconds: refer to [httpx_ws.aconnect_ws][] Tip: - [`httpx_ws.aconnect_ws`](https://frankie567.github.io/httpx-ws/reference/httpx_ws/#httpx_ws.aconnect_ws) + [httpx_ws.aconnect_ws][] """ self.base_url = check_base_url(base_url) super().__init__( @@ -767,7 +1075,12 @@ def __init__( @override async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] - self, *, websocket: starlette_ws.WebSocket, path: Optional[str] = None + self, + *, + websocket: starlette_ws.WebSocket, + path: Optional[str] = None, + client_to_server_callback: Optional[_CallbackType[_WsMsgTypeVar_CTS]] = None, + server_to_client_callback: Optional[_CallbackType[_WsMsgTypeVar_STC]] = None, ) -> Union[Literal[False], StarletteResponse]: """Establish websocket connection for both client and target_url, then pass messages between them. @@ -776,6 +1089,12 @@ async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] path: The path params of websocket request, which means the path params of base url.
If None, will get it from `websocket.path_params`.
**Usually, you don't need to pass this argument**. + client_to_server_callback: The callback function for processing data from client to server. + The implementation details are in the base class + [`BaseWebSocketProxy.send_request_to_target`][fastapi_proxy_lib.core.websocket.BaseWebSocketProxy.send_request_to_target--callback-implementation]. + server_to_client_callback: The callback function for processing data from server to client. + The implementation details are in the base class + [`BaseWebSocketProxy.send_request_to_target`][fastapi_proxy_lib.core.websocket.BaseWebSocketProxy.send_request_to_target--callback-implementation]. Returns: If the establish websocket connection unsuccessfully: @@ -800,5 +1119,8 @@ async def proxy( # pyright: ignore [reportIncompatibleMethodOverride] # self.send_request_to_target 内部会处理连接失败时,返回错误给客户端,所以这里不处理了 return await self.send_request_to_target( - websocket=websocket, target_url=target_url + websocket=websocket, + target_url=target_url, + client_to_server_callback=client_to_server_callback, + server_to_client_callback=server_to_client_callback, ) diff --git a/tests/test_docs_examples.py b/tests/test_docs_examples.py index 12bea06..91cbf91 100644 --- a/tests/test_docs_examples.py +++ b/tests/test_docs_examples.py @@ -26,7 +26,7 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: async def _(request: Request, path: str = ""): return await proxy.proxy(request=request, path=path) - # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` + # Then run shell: `uvicorn :app --host 127.0.0.1 --port 8000` # visit the app: `http://127.0.0.1:8000/http://www.example.com` # you will get the response from `http://www.example.com` @@ -55,7 +55,7 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: # (1)! async def _(request: Request, path: str = ""): return await proxy.proxy(request=request, path=path) # (3)! - # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` + # Then run shell: `uvicorn :app --host 127.0.0.1 --port 8000` # visit the app: `http://127.0.0.1:8000/` # you will get the response from `http://www.example.com/` @@ -93,7 +93,7 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: async def _(websocket: WebSocket, path: str = ""): return await proxy.proxy(websocket=websocket, path=path) - # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000` + # Then run shell: `uvicorn :app --host 127.0.0.1 --port 8000` # visit the app: `ws://127.0.0.1:8000/` # you can establish websocket connection with `ws://echo.websocket.events` diff --git a/tests/test_ws.py b/tests/test_ws.py index 2119719..3a8e464 100644 --- a/tests/test_ws.py +++ b/tests/test_ws.py @@ -1,18 +1,30 @@ # noqa: D100 - import asyncio -from contextlib import AsyncExitStack +import gc +from contextlib import AsyncExitStack, asynccontextmanager from multiprocessing import Process, Queue -from typing import Any, Dict, Literal, Optional +from typing import Any, AsyncIterator, Dict, Literal, Optional, Protocol +import anyio import httpx import httpx_ws import pytest import uvicorn +from anyio import move_on_after +from fastapi import FastAPI +from fastapi_proxy_lib.core.websocket import ( + CallbackPipeContextType, + ReverseWebSocketProxy, + # from this project, so it's ok + _CallbackType, # pyright: ignore[reportPrivateUsage] + _WsMsgTypeVar_CTS, # pyright: ignore[reportPrivateUsage] + _WsMsgTypeVar_STC, # pyright: ignore[reportPrivateUsage] +) from fastapi_proxy_lib.fastapi.app import reverse_ws_app as get_reverse_ws_app from httpx_ws import aconnect_ws from starlette import websockets as starlette_websockets_module +from starlette.websockets import WebSocket as StarletteWebSocket from typing_extensions import override from .app.echo_ws_app import get_app as get_ws_test_app @@ -37,6 +49,22 @@ NO_PROXIES: Dict[Any, Any] = {"all://": None} +class WsAppFactory(Protocol): + """The factory for `Tool4TestFixture`.""" + + def __call__(self, client: httpx.AsyncClient, *, base_url: str) -> FastAPI: + """Return the ws proxy app for testing.""" + ... + + +class Tool4TestFixtureFactory(Protocol): + """The factory for `Tool4TestFixture`.""" + + async def __call__(self, ws_app_factory: WsAppFactory) -> Tool4TestFixture: + """See the implementation for details.""" + ... + + def _subprocess_run_echo_ws_uvicorn_server(queue: "Queue[str]", **kwargs: Any): """Run echo ws app in subprocess. @@ -101,54 +129,109 @@ async def run(): asyncio.run(run()) +def callback_ws_app_factory_builder( + client_to_server_callback: Optional[_CallbackType[_WsMsgTypeVar_CTS]] = None, + server_to_client_callback: Optional[_CallbackType[_WsMsgTypeVar_STC]] = None, +) -> WsAppFactory: + """Return a ws proxy app factory with callback.""" + + def callback_ws_app_factory(client: httpx.AsyncClient, *, base_url: str) -> FastAPI: + """Return a ws proxy app with callback.""" + proxy = ReverseWebSocketProxy( + client=client, + base_url=base_url, + ) + + @asynccontextmanager + async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: + """Close proxy.""" + yield + await proxy.aclose() + + app = FastAPI(lifespan=close_proxy_event) + + @app.websocket("/{path:path}") + async def _(websocket: StarletteWebSocket, path: str = ""): + return await proxy.proxy( + websocket=websocket, + path=path, + client_to_server_callback=client_to_server_callback, + server_to_client_callback=server_to_client_callback, + ) + + return app + + return callback_ws_app_factory + + class TestReverseWsProxy(AbstractTestProxy): """For testing reverse websocket proxy.""" - @override @pytest.fixture(params=WS_BACKENDS_NEED_BE_TESTED) - async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverride] + def tool_4_test_fixture_factory( self, uvicorn_server_fixture: UvicornServerFixture, request: pytest.FixtureRequest, - ) -> Tool4TestFixture: + ) -> Tool4TestFixtureFactory: """目标服务器请参考`tests.app.echo_ws_app.get_app`.""" - echo_ws_test_model = get_ws_test_app() - echo_ws_app = echo_ws_test_model.app - echo_ws_get_request = echo_ws_test_model.get_request - target_ws_server = await uvicorn_server_fixture( - uvicorn.Config( - echo_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=request.param - ), - contx_exit_timeout=DEFAULT_CONTX_EXIT_TIMEOUT, - ) + async def _tool_4_test_fixture_factory( + ws_app_factory: WsAppFactory, + ) -> Tool4TestFixture: + """Create a tool for testing reverse websocket proxy. + + Args: + ws_app_factory: A app factory for create reverse proxy websocket app, + """ + echo_ws_test_model = get_ws_test_app() + echo_ws_app = echo_ws_test_model.app + echo_ws_get_request = echo_ws_test_model.get_request + + target_ws_server = await uvicorn_server_fixture( + uvicorn.Config( + echo_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=request.param + ), + contx_exit_timeout=DEFAULT_CONTX_EXIT_TIMEOUT, + ) - target_server_base_url = str(target_ws_server.contx_socket_url) + target_server_base_url = str(target_ws_server.contx_socket_url) - client_for_conn_to_target_server = httpx.AsyncClient(proxies=NO_PROXIES) + client_for_conn_to_target_server = httpx.AsyncClient(proxies=NO_PROXIES) - reverse_ws_app = get_reverse_ws_app( - client=client_for_conn_to_target_server, base_url=target_server_base_url - ) + reverse_ws_app = ws_app_factory( + client=client_for_conn_to_target_server, base_url=target_server_base_url + ) - proxy_ws_server = await uvicorn_server_fixture( - uvicorn.Config( - reverse_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=request.param - ), - contx_exit_timeout=DEFAULT_CONTX_EXIT_TIMEOUT, - ) + proxy_ws_server = await uvicorn_server_fixture( + uvicorn.Config( + reverse_ws_app, + port=DEFAULT_PORT, + host=DEFAULT_HOST, + ws=request.param, + ), + contx_exit_timeout=DEFAULT_CONTX_EXIT_TIMEOUT, + ) - proxy_server_base_url = str(proxy_ws_server.contx_socket_url) + proxy_server_base_url = str(proxy_ws_server.contx_socket_url) - client_for_conn_to_proxy_server = httpx.AsyncClient(proxies=NO_PROXIES) + client_for_conn_to_proxy_server = httpx.AsyncClient(proxies=NO_PROXIES) - return Tool4TestFixture( - client_for_conn_to_target_server=client_for_conn_to_target_server, - client_for_conn_to_proxy_server=client_for_conn_to_proxy_server, - get_request=echo_ws_get_request, - target_server_base_url=target_server_base_url, - proxy_server_base_url=proxy_server_base_url, - ) + return Tool4TestFixture( + client_for_conn_to_target_server=client_for_conn_to_target_server, + client_for_conn_to_proxy_server=client_for_conn_to_proxy_server, + get_request=echo_ws_get_request, + target_server_base_url=target_server_base_url, + proxy_server_base_url=proxy_server_base_url, + ) + + return _tool_4_test_fixture_factory + + @override + @pytest.fixture() + async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverride] + self, tool_4_test_fixture_factory: Tool4TestFixtureFactory + ) -> Tool4TestFixture: + return await tool_4_test_fixture_factory(get_reverse_ws_app) @pytest.mark.anyio() async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None: @@ -313,3 +396,160 @@ async def test_target_server_shutdown_abnormally( # 只要第二个客户端不是在之前40s基础上又重复40s,就暂时没问题, # 因为这模拟了多个客户端进行连接的情况。 assert (seconde_ws_recv_end - seconde_ws_recv_start) < 2 + + @pytest.mark.timeout(15) # prevent dead lock + @pytest.mark.anyio() + async def test_ws_proxy_with_callback( + self, tool_4_test_fixture_factory: Tool4TestFixtureFactory + ) -> None: + """Test reverse websocket proxy with callback.""" + msg = "foo" + cts_prefix = "CTS:" + stc_prefix = "STC:" + + cts_cb_receive = None + stc_cb_receive = None + cts_cb_done = anyio.Event() + stc_cb_done = anyio.Event() + + async def client_to_server_callback( + pipe_context: CallbackPipeContextType[str], + ) -> None: + nonlocal cts_cb_receive, cts_cb_done + with pipe_context as (sender, receiver): + async for message in receiver: + cts_cb_receive = message + await sender.send(f"{cts_prefix}{message}") + cts_cb_done.set() + + async def server_to_client_callback( + pipe_context: CallbackPipeContextType[str], + ) -> None: + with pipe_context as (sender, receiver): + nonlocal stc_cb_receive, stc_cb_done + async for message in receiver: + stc_cb_receive = message + await sender.send(f"{stc_prefix}{message}") + stc_cb_done.set() + + tool_4_test_fixture = await tool_4_test_fixture_factory( + callback_ws_app_factory_builder( + client_to_server_callback=client_to_server_callback, + server_to_client_callback=server_to_client_callback, + ) + ) + proxy_server_base_url = tool_4_test_fixture.proxy_server_base_url + client_for_conn_to_proxy_server = ( + tool_4_test_fixture.client_for_conn_to_proxy_server + ) + get_request = tool_4_test_fixture.get_request + + async with aconnect_ws( + proxy_server_base_url + "echo_text", client_for_conn_to_proxy_server + ) as ws: + await ws.send_text(msg) + + assert ( + await ws.receive_text() == f"{stc_prefix}{cts_prefix}{msg}" + ), "cts send wrong message" + assert cts_cb_receive == msg, "cts receive wrong message" + assert ( + stc_cb_receive == f"{cts_prefix}{msg}" + ), "cts send wrong message or stc receive wrong message" + + with move_on_after(2) as scope: + await cts_cb_done.wait() + await stc_cb_done.wait() + assert ( + not scope.cancelled_caught + ), "after proxy finished, callback not done or done too late" + + target_starlette_ws = get_request() + assert isinstance(target_starlette_ws, starlette_websockets_module.WebSocket) + # test target ws has disconnected + with pytest.raises(RuntimeError): + await target_starlette_ws.receive_text() + + @pytest.mark.timeout(15) # prevent dead lock + @pytest.mark.anyio() + @pytest.mark.parametrize("forget_to_enter_pipe", [True, False]) + async def test_ws_callback_error( + self, + tool_4_test_fixture_factory: Tool4TestFixtureFactory, + forget_to_enter_pipe: bool, + ) -> None: + """Test reverse websocket proxy with callback that will raise error.""" + msg = "foo" + raise_exception_event = anyio.Event() + + async def client_to_server_callback( + pipe_context: CallbackPipeContextType[str], + ) -> None: + """We do nothing, just raise exception.""" + if not forget_to_enter_pipe: + # NOTE: we must ensure that exit the context manager to close the pipe, + # or will get dead lock. + with pipe_context as (_sender, _receiver): + await raise_exception_event.wait() + raise Exception() + else: + # Here, we forget to enter the pipe context, so the context never be exited. + # Because there are some mitigation measures, we will not get dead lock, + # but we still get warning. + await raise_exception_event.wait() + raise Exception() + + class NullContext: + def __init__(self, *args: object, **kwargs: object) -> None: + pass + + def __enter__(self, *args: object, **kwargs: object) -> None: + pass + + def __exit__(self, *args: object, **kwargs: object) -> None: + pass + + def null_collector() -> None: + pass + + if not forget_to_enter_pipe: + # use `gc.collect` to make sure `__del__` method of pipe context be called, + # so that we can get the warning. + warning_capturer, gc_collector = pytest.warns, gc.collect + else: + warning_capturer, gc_collector = NullContext, null_collector + + with warning_capturer( + RuntimeWarning, + match="You never exit the pipe context, it may cause a deadlock.", + ): + tool_4_test_fixture = await tool_4_test_fixture_factory( + callback_ws_app_factory_builder( + client_to_server_callback=client_to_server_callback, + ) + ) + proxy_server_base_url = tool_4_test_fixture.proxy_server_base_url + client_for_conn_to_proxy_server = ( + tool_4_test_fixture.client_for_conn_to_proxy_server + ) + get_request = tool_4_test_fixture.get_request + + async with aconnect_ws( + proxy_server_base_url + "echo_text", client_for_conn_to_proxy_server + ) as ws: + await ws.send_text(msg) + raise_exception_event.set() + + # client ws has disconnected + with pytest.raises(httpx_ws.WebSocketDisconnect): + await ws.receive_text() + + target_starlette_ws = get_request() + assert isinstance( + target_starlette_ws, starlette_websockets_module.WebSocket + ) + # test target ws has disconnected + with pytest.raises(RuntimeError): + await target_starlette_ws.receive_text() + + gc_collector()