diff --git a/CHANGELOG.md b/CHANGELOG.md
index 336f974..0fb92f7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
+### Added
+
+- [#41](https://github.com/WSH032/fastapi-proxy-lib/pull/41) - feat: add `callback` api to `WebSocketProxy`. Thanks [@WSH032](https://github.com/WSH032) and [@IvesAwadi](https://github.com/IvesAwadi)!
+
### Changed
- [#30](https://github.com/WSH032/fastapi-proxy-lib/pull/30) - fix(internal): use `websocket` in favor of `websocket_route`. Thanks [@WSH032](https://github.com/WSH032)!
diff --git a/README.md b/README.md
index 56d2f71..5c5da6c 100644
--- a/README.md
+++ b/README.md
@@ -32,8 +32,9 @@ Source Code:
- [x] Support both **reverse** proxy and **forward** proxy.
- [x] **Transparently** and **losslessly** handle all proxy requests,
Including **HTTP headers**, **cookies**, **query parameters**, **body**, etc.
+- [X] WebSocket proxy **callback**.
- [x] Asynchronous streaming transfer, support **file proxy**.
-- [x] `fastapi-proxy-lib` value [privacy security](https://wsh032.github.io/fastapi-proxy-lib/Usage/Security/).
+- [x] `fastapi-proxy-lib` value [**privacy security**](https://wsh032.github.io/fastapi-proxy-lib/Usage/Security/).
### other features
diff --git a/docs/Usage/Advanced.md b/docs/Usage/Advanced.md
index b2e522d..0951399 100644
--- a/docs/Usage/Advanced.md
+++ b/docs/Usage/Advanced.md
@@ -4,7 +4,8 @@ For the following scenarios, you might prefer [fastapi_proxy_lib.core][]:
- When you need to use proxies with **only** `Starlette` dependencies (without `FastAPI`).
- When you need more fine-grained control over parameters and lifespan event.
-- When you need to further process the input and output before and after the proxy (similar to middleware).
+- When you need to further process the input and output before and after the http proxy (similar to `middleware`).
+- When you need `callback` to modify the websocket proxy messages.
We will demonstrate with `FastAPI`,
but you can completely switch to the `Starlette` approach,
@@ -19,13 +20,13 @@ Also (without annotations):
- [`ForwardHttpProxy#examples`][fastapi_proxy_lib.core.http.ForwardHttpProxy--examples]
- [`ReverseWebSocketProxy#examples`][fastapi_proxy_lib.core.websocket.ReverseWebSocketProxy--examples]
-## Modify request
+## Modify HTTP request
In some cases, you may want to make final modifications before sending a request, such as performing behind-the-scenes authentication by modifying the headers of request.
`httpx` provides comprehensive authentication support, and `fastapi-proxy-lib` offers first-class support for `httpx`.
-See
+See
You can refer following example to implement a simple authentication:
@@ -35,7 +36,7 @@ from fastapi_proxy_lib.fastapi.app import reverse_http_app
class MyCustomAuth(httpx.Auth):
- # ref: https://www.python-httpx.org/advanced/#customizing-authentication
+ # ref: https://www.python-httpx.org/advanced/authentication/
def __init__(self, token: str):
self.token = token
@@ -55,7 +56,7 @@ app = reverse_http_app(
visit `/headers` to see the result which contains `"X-Authentication": "bearer_token"` header.
-## Modify response
+## Modify HTTP response
In some cases, you may want to make final modifications before return the response to the client, such as transcoding video response streams.
@@ -118,3 +119,19 @@ async def _(request: Request, path: str = ""):
```
visit `/`, you will notice that the response body is printed to the console.
+
+## Modify WebSocket message
+
+In some cases, you might want to modify the content of the messages that the WebSocket proxy receives and sends to the client and target server.
+
+In version `0.2.0` of `fastapi-proxy-lib`, we introduced a [`callback API`][fastapi_proxy_lib.core.websocket.ReverseWebSocketProxy.proxy] for `WebSocketProxy` to allow you to do this.
+
+See example: [ReverseWebSocketProxy#with-callback][fastapi_proxy_lib.core.websocket.ReverseWebSocketProxy--with-callback]
+
+Also:
+
+- RFC: [#40](https://github.com/WSH032/fastapi-proxy-lib/issues/40)
+- PR: [#41](https://github.com/WSH032/fastapi-proxy-lib/pull/41)
+
+!!!example
+ The current implementation still has some defects. Read the [callback-implementation][fastapi_proxy_lib.core.websocket.BaseWebSocketProxy.send_request_to_target--callback-implementation] section, or you might accidentally shoot yourself in the foot.
diff --git a/docs/Usage/FastAPI-Helper.md b/docs/Usage/FastAPI-Helper.md
index 45f1ebe..9a25600 100644
--- a/docs/Usage/FastAPI-Helper.md
+++ b/docs/Usage/FastAPI-Helper.md
@@ -10,7 +10,7 @@ There are two helper modules to get FastAPI `app`/`router` for proxy convenientl
## app
-use `fastapi_proxy_lib.fastapi.app` is very convenient and out of the box, there are three helper functions:
+`fastapi_proxy_lib.fastapi.app` is very convenient and out of the box, there are three helper functions:
- [forward_http_app][fastapi_proxy_lib.fastapi.app.forward_http_app]
- [reverse_http_app][fastapi_proxy_lib.fastapi.app.reverse_http_app]
@@ -46,3 +46,9 @@ For the following scenarios, you might prefer [fastapi_proxy_lib.fastapi.router]
- When you need to [mount the proxy on a route of larger app](https://fastapi.tiangolo.com/tutorial/bigger-applications/).
**^^[Please refer to the documentation of `RouterHelper` for more information :material-file-document: ][fastapi_proxy_lib.fastapi.router.RouterHelper--examples]^^**.
+
+---
+
+## More
+
+**The `Helper Module` might not meet your further customization needs. Please refer to the [Advanced](Advanced.md) section, which is the core of `fastapi-proxy-lib`, for more personalization options.**
diff --git a/mkdocs.yml b/mkdocs.yml
index d4c2e3a..26884ad 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -108,6 +108,8 @@ plugins:
import:
- https://frankie567.github.io/httpx-ws/objects.inv
- https://fastapi.tiangolo.com/objects.inv
+ - https://anyio.readthedocs.io/en/stable/objects.inv
+ - https://docs.python.org/3/objects.inv
options:
docstring_style: google
paths: [src]
diff --git a/pyproject.toml b/pyproject.toml
index f3c9bbf..f545340 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,7 +52,8 @@ dependencies = [
"httpx",
"httpx-ws >= 0.4.2",
"starlette",
- "typing_extensions >=4.5.0",
+ "typing_extensions >= 4.12",
+ "anyio >= 4",
]
[project.optional-dependencies]
@@ -91,13 +92,12 @@ dependencies = [
# NOTE: 👆
# lint-check
- "pyright == 1.1.356", # pyright must be installed in the runtime environment
+ "pyright == 1.1.372", # pyright must be installed in the runtime environment
# test
"pytest == 7.*",
"pytest-cov == 4.*",
"uvicorn[standard] < 1.0.0", # TODO: Once it releases version 1.0.0, we will remove this restriction.
"httpx[http2]", # we don't set version here, instead set it in `[project].dependencies`.
- "anyio", # we don't set version here, because fastapi has a dependency on it
"asgi-lifespan==2.*",
"pytest-timeout==2.*",
]
diff --git a/src/fastapi_proxy_lib/core/http.py b/src/fastapi_proxy_lib/core/http.py
index fab3316..e891502 100644
--- a/src/fastapi_proxy_lib/core/http.py
+++ b/src/fastapi_proxy_lib/core/http.py
@@ -222,7 +222,7 @@ class BaseHttpProxy(BaseProxyModel):
"""
@override
- async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOverride]
+ async def send_request_to_target(
self, *, request: StarletteRequest, target_url: httpx.URL
) -> StarletteResponse:
"""Change request headers and send request to target url.
@@ -318,6 +318,8 @@ class ReverseHttpProxy(BaseHttpProxy):
# # Examples
+ ## Basic usage
+
```python
from contextlib import asynccontextmanager
from typing import AsyncIterator
@@ -341,7 +343,7 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: # (1)!
async def _(request: Request, path: str = ""):
return await proxy.proxy(request=request, path=path) # (3)!
- # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000`
+ # Then run shell: `uvicorn :app --host 127.0.0.1 --port 8000`
# visit the app: `http://127.0.0.1:8000/`
# you will get the response from `http://www.example.com/`
```
@@ -452,6 +454,8 @@ class ForwardHttpProxy(BaseHttpProxy):
# # Examples
+ ## Basic usage
+
```python
from contextlib import asynccontextmanager
from typing import AsyncIterator
@@ -476,7 +480,7 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]:
async def _(request: Request, path: str = ""):
return await proxy.proxy(request=request, path=path)
- # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000`
+ # Then run shell: `uvicorn :app --host 127.0.0.1 --port 8000`
# visit the app: `http://127.0.0.1:8000/http://www.example.com`
# you will get the response from `http://www.example.com`
```
diff --git a/src/fastapi_proxy_lib/core/websocket.py b/src/fastapi_proxy_lib/core/websocket.py
index 36dacd7..5c343a2 100644
--- a/src/fastapi_proxy_lib/core/websocket.py
+++ b/src/fastapi_proxy_lib/core/websocket.py
@@ -2,21 +2,36 @@
import asyncio
import logging
-from contextlib import AsyncExitStack
+import warnings
+from contextlib import (
+ AbstractContextManager,
+ AsyncExitStack,
+ ExitStack,
+ asynccontextmanager,
+ nullcontext,
+)
from typing import (
TYPE_CHECKING,
Any,
+ AsyncIterator,
+ Callable,
+ ContextManager,
+ Coroutine,
+ Generic,
List,
Literal,
NamedTuple,
NoReturn,
Optional,
+ Tuple,
Union,
)
import httpx
import httpx_ws
import starlette.websockets as starlette_ws
+from anyio import create_memory_object_stream
+from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_ws._api import ( # HACK: 注意,这个是私有模块
DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS,
DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS,
@@ -29,6 +44,7 @@
from starlette.responses import StreamingResponse
from starlette.types import Scope
from typing_extensions import TypeAlias, override
+from typing_extensions import TypeVar as TypeVarExt
from wsproto.events import BytesMessage as WsprotoBytesMessage
from wsproto.events import TextMessage as WsprotoTextMessage
@@ -39,10 +55,7 @@
check_http_version,
)
-__all__ = (
- "BaseWebSocketProxy",
- "ReverseWebSocketProxy",
-)
+__all__ = ("BaseWebSocketProxy", "ReverseWebSocketProxy", "CallbackPipeContextType")
if TYPE_CHECKING:
# 这些是私有模块,无法确定以后版本是否会改变,为了保证运行时不会出错,我们使用TYPE_CHECKING
@@ -56,6 +69,39 @@
_ServerToClientTaskType: TypeAlias = "asyncio.Task[httpx_ws.WebSocketDisconnect]"
+_DefaultWsMsgType = Union[str, bytes]
+_WsMsgType = Union[str, bytes, _DefaultWsMsgType]
+"""Websocket message type."""
+_WsMsgTypeVar = TypeVarExt("_WsMsgTypeVar", bound=_WsMsgType, default=_DefaultWsMsgType)
+"""Generic websocket message type."""
+_CallbackPipeType = Tuple[
+ MemoryObjectSendStream[_WsMsgTypeVar], MemoryObjectReceiveStream[_WsMsgTypeVar]
+]
+"""Send end and receive end of a callback pipe."""
+CallbackPipeContextType = ContextManager[_CallbackPipeType[_WsMsgTypeVar]]
+"""A context manager that will automatically close the pipe when exit.
+
+The `__enter__` method will return one end of each pair of
+[`memory-object-streams`](https://anyio.readthedocs.io/en/stable/streams.html#memory-object-streams) pipelines.
+
+See example: [ReverseWebSocketProxy#with-callback][fastapi_proxy_lib.core.websocket.ReverseWebSocketProxy--with-callback]
+
+Warning:
+ This is a unstable public type hint, you shouldn't rely on it.
+ You should create your own type hint instead.
+
+Note:
+ You must ensure that **exit** the context manager,
+ or maybe you will get a **deadlock**.
+
+ See: [`callback-implementation`][fastapi_proxy_lib.core.websocket.BaseWebSocketProxy.send_request_to_target--callback-implementation]
+"""
+_CallbackType = Callable[
+ [CallbackPipeContextType[_WsMsgTypeVar]], Coroutine[None, None, None]
+]
+"""The websocket callback provided by user."""
+
+
class _ClientServerProxyTask(NamedTuple):
"""The task group for passing websocket message between client and target server."""
@@ -70,6 +116,8 @@ class _ClientServerProxyTask(NamedTuple):
SUPPORTED_WS_HTTP_VERSIONS = ("1.1",)
"""The http versions that we supported now. It depends on `httpx`."""
+_CALLBACK_BUFFER_SIZE = 0
+"""The buffer size of the callback pipe."""
#################### Error ####################
@@ -97,6 +145,84 @@ def _get_client_request_subprotocols(ws_scope: Scope) -> Union[List[str], None]:
return subprotocols
+class _PipeContextBuilder(
+ # before py3.9, `AbstractContextManager` is not generic
+ AbstractContextManager, # pyright: ignore[reportMissingTypeArgument]
+ Generic[_WsMsgTypeVar],
+):
+ """Auto close the pipe when exit the context."""
+
+ def __init__(self, raw_pipe: _CallbackPipeType[_WsMsgTypeVar]) -> None:
+ self._raw_pipe = raw_pipe
+ self._exit_stack = ExitStack()
+ self._exited = False
+
+ def __enter__(self) -> _CallbackPipeType[_WsMsgTypeVar]:
+ sender, receiver = self._raw_pipe
+ self._exit_stack.enter_context(sender)
+ self._exit_stack.enter_context(receiver)
+ return self._raw_pipe
+
+ def __exit__(self, *_: Any) -> None:
+ self._exit_stack.__exit__(*_)
+ self._exited = True
+
+ def __del__(self):
+ if not self._exited:
+ warnings.warn(
+ "You never exit the pipe context, it may cause a deadlock.",
+ RuntimeWarning,
+ stacklevel=1,
+ )
+
+
+@asynccontextmanager
+async def _wait_task_and_ignore_exce(
+ callback_task: "asyncio.Task[Any]",
+) -> AsyncIterator[None]:
+ """Wait for the task when exiting, but ignore the exception."""
+ yield
+ try:
+ await callback_task
+ except Exception:
+ pass
+
+
+async def _enable_callback(
+ callback: _CallbackType[_WsMsgTypeVar], task_name: str, exit_stack: AsyncExitStack
+) -> CallbackPipeContextType[_WsMsgTypeVar]:
+ """Create a task to run the callback.
+
+ - The callback task will be awaited when exit the `exit_stack`,
+ but the exception of callback task will be ignored.
+ - When the callback done(normal or exception),
+ the pipe used in the callback will be closed;
+ this is for preventing callback forgetting to close the pipe and causing deadlock.
+ ```py
+ async def callback(ctx: CallbackPipeContextType[str]) -> None:
+ pass
+ ```
+ NOTE: This is just a mitigation measure and may be removed in the future,
+ so it needs to be documented in the public documentation.
+ """
+ proxy_sender, cb_receiver = create_memory_object_stream[_WsMsgTypeVar](
+ _CALLBACK_BUFFER_SIZE
+ )
+ cb_sender, proxy_receiver = create_memory_object_stream[_WsMsgTypeVar](
+ _CALLBACK_BUFFER_SIZE
+ )
+ cb_pipe_ctx = _PipeContextBuilder((cb_sender, cb_receiver))
+ proxy_pipe_ctx = _PipeContextBuilder((proxy_sender, proxy_receiver))
+
+ cb_task = asyncio.create_task(callback(cb_pipe_ctx), name=task_name)
+ # use `done_callback` to close the pipe when the callback task is done,
+ # NOTE: we close `sender` and `receiver` directly, instead of using `cb_pipe_ctx.__exit__`,
+ # so that `_PipeContextBuilder` issue a warning when the callback forgets to close the pipe.
+ cb_task.add_done_callback(lambda _: (cb_sender.close(), cb_receiver.close()))
+ await exit_stack.enter_async_context(_wait_task_and_ignore_exce(cb_task))
+ return proxy_pipe_ctx
+
+
# TODO: 等待starlette官方的支持
# 为什么使用这个函数而不是直接使用starlette_WebSocket.receive_text()
# 请看: https://github.com/encode/starlette/discussions/2310
@@ -261,13 +387,19 @@ async def _starlette_ws_send_bytes_or_str(
async def _wait_client_then_send_to_server(
- *, client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession
+ *,
+ client_ws: starlette_ws.WebSocket,
+ server_ws: httpx_ws.AsyncWebSocketSession,
+ pipe_context: Optional[CallbackPipeContextType[_WsMsgTypeVar]] = None,
) -> starlette_ws.WebSocketDisconnect:
"""Receive data from client, then send to target server.
Args:
client_ws: The websocket which receive data of client.
server_ws: The websocket which send data to target server.
+ pipe_context: The callback pipe for processing data.
+ will send the received data(from client) to the sender,
+ and receive the data from the receiver(then send to the server).
Returns:
If the client_ws sends a shutdown message normally, will return starlette_ws.WebSocketDisconnect.
@@ -275,24 +407,39 @@ async def _wait_client_then_send_to_server(
Raises:
error for receiving: refer to `_starlette_ws_receive_bytes_or_str`
error for sending: refer to `_httpx_ws_send_bytes_or_str`
+ error for callback: refer to `MemoryObjectReceiveStream.receive` and `MemoryObjectSendStream.send`
"""
- while True:
- try:
- receive = await _starlette_ws_receive_bytes_or_str(client_ws)
- except starlette_ws.WebSocketDisconnect as e:
- return e
- else:
+ with pipe_context or nullcontext() as pipe:
+ while True:
+ try:
+ receive = await _starlette_ws_receive_bytes_or_str(client_ws)
+ except starlette_ws.WebSocketDisconnect as e:
+ return e
+
+ # TODO: do not use `if` statement in loop
+ if pipe is not None:
+ sender, receiver = pipe
+ # XXX, HACK, TODO: We can't identify the msg type from websocket,
+ # so we have to igonre the type check here.
+ await sender.send(receive) # pyright: ignore [reportArgumentType]
+ receive = await receiver.receive()
await _httpx_ws_send_bytes_or_str(server_ws, receive)
async def _wait_server_then_send_to_client(
- *, client_ws: starlette_ws.WebSocket, server_ws: httpx_ws.AsyncWebSocketSession
+ *,
+ client_ws: starlette_ws.WebSocket,
+ server_ws: httpx_ws.AsyncWebSocketSession,
+ pipe_context: Optional[CallbackPipeContextType[_WsMsgTypeVar]] = None,
) -> httpx_ws.WebSocketDisconnect:
"""Receive data from target server, then send to client.
Args:
client_ws: The websocket which send data to client.
server_ws: The websocket which receive data of target server.
+ pipe_context: The callback pipe for processing data.
+ will send the received data(from server) to the sender,
+ and receive the data from the receiver(then send to the client).
Returns:
If the server_ws sends a shutdown message normally, will return httpx_ws.WebSocketDisconnect.
@@ -300,13 +447,22 @@ async def _wait_server_then_send_to_client(
Raises:
error for receiving: refer to `_httpx_ws_receive_bytes_or_str`
error for sending: refer to `_starlette_ws_send_bytes_or_str`
+ error for callback: refer to `MemoryObjectReceiveStream.receive` and `MemoryObjectSendStream.send`
"""
- while True:
- try:
- receive = await _httpx_ws_receive_bytes_or_str(server_ws)
- except httpx_ws.WebSocketDisconnect as e:
- return e
- else:
+ with pipe_context or nullcontext() as pipe:
+ while True:
+ try:
+ receive = await _httpx_ws_receive_bytes_or_str(server_ws)
+ except httpx_ws.WebSocketDisconnect as e:
+ return e
+
+ # TODO: do not use `if` statement in loop
+ if pipe is not None:
+ sender, receiver = pipe
+ # XXX, HACK, TODO: We can't identify the msg type from websocket,
+ # so we have to igonre the type check here.
+ await sender.send(receive) # pyright: ignore [reportArgumentType]
+ receive = await receiver.receive()
await _starlette_ws_send_bytes_or_str(client_ws, receive)
@@ -396,6 +552,14 @@ async def _close_ws(
#################### # ####################
+_WsMsgTypeVar_CTS = TypeVarExt(
+ "_WsMsgTypeVar_CTS", bound=_WsMsgType, default=_DefaultWsMsgType
+)
+_WsMsgTypeVar_STC = TypeVarExt(
+ "_WsMsgTypeVar_STC", bound=_WsMsgType, default=_DefaultWsMsgType
+)
+
+
class BaseWebSocketProxy(BaseProxyModel):
"""Websocket proxy base class.
@@ -408,7 +572,7 @@ class BaseWebSocketProxy(BaseProxyModel):
keepalive_ping_timeout_seconds: refer to [httpx_ws.aconnect_ws][]
Tip:
- [`httpx_ws.aconnect_ws`](https://frankie567.github.io/httpx-ws/reference/httpx_ws/#httpx_ws.aconnect_ws)
+ [httpx_ws.aconnect_ws][]
"""
client: httpx.AsyncClient
@@ -447,7 +611,7 @@ def __init__(
keepalive_ping_timeout_seconds: refer to [httpx_ws.aconnect_ws][]
Tip:
- [`httpx_ws.aconnect_ws`](https://frankie567.github.io/httpx-ws/reference/httpx_ws/#httpx_ws.aconnect_ws)
+ [httpx_ws.aconnect_ws][]
"""
self.max_message_size_bytes = max_message_size_bytes
self.queue_size = queue_size
@@ -456,11 +620,13 @@ def __init__(
super().__init__(client, follow_redirects=follow_redirects)
@override
- async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOverride]
+ async def send_request_to_target(
self,
*,
websocket: starlette_ws.WebSocket,
target_url: httpx.URL,
+ client_to_server_callback: Optional[_CallbackType[_WsMsgTypeVar_CTS]] = None,
+ server_to_client_callback: Optional[_CallbackType[_WsMsgTypeVar_STC]] = None,
) -> Union[Literal[False], StarletteResponse]:
"""Establish websocket connection for both client and target_url, then pass messages between them.
@@ -469,6 +635,76 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv
Args:
websocket: The client websocket requests.
target_url: The url of target websocket server.
+ client_to_server_callback: The callback function for processing data from client to server.
+ The usage example is in subclass [ReverseWebSocketProxy#with-callback][fastapi_proxy_lib.core.websocket.ReverseWebSocketProxy--with-callback].
+ server_to_client_callback: The callback function for processing data from server to client.
+ The usage example is in subclass [ReverseWebSocketProxy#with-callback][fastapi_proxy_lib.core.websocket.ReverseWebSocketProxy--with-callback].
+
+ ## Callback implementation
+
+ Note: The `callback` implementation details:
+ - If `callback` is not None, will create a new task to run the callback function.
+ When the whole proxy task finishes, the callback task will be awaited, but the exception will be ignored.
+ - `callback` must ensure that it closes the pipe(exit the pipe context) when it finishes or encounters an exception,
+ or you will get a **deadlock**.
+ - A common mistake is the `callback` encountering an exception or returning before entering the context.
+ ```py
+ async def callback(ctx: CallbackPipeContextType[str]) -> None:
+ # mistake: not entering the context
+ return
+
+ async def callback(ctx: CallbackPipeContextType[str]) -> None:
+ # mistake: encountering an exception before entering the context
+ 1 / 0
+ with ctx as (sender, receiver):
+ pass
+ ```
+ - If the callback-side pipe is closed but proxy task is still running,
+ `proxy` will treat it as a exception and close websocket connection.
+ - If `proxy` encounters an exception or receives a disconnection request, the proxy-side pipe will be closed,
+ then the callback-side pipe will receive an exception
+ (refer to [send][anyio.streams.memory.MemoryObjectSendStream.send],
+ also [receive][anyio.streams.memory.MemoryObjectReceiveStream.receive]).
+ - **The buffer size of the pipe is currently 0 (this may change in the future)**,
+ which means that if the `callback` does not call the receiver, the `WebSocket` will block.
+ Therefore, the `callback` should not take too long to process a single message.
+ If you expect to be unable to call the receiver for an extended period,
+ you need to create your own buffer to store messages.
+
+ See also:
+
+ - [RFC#40](https://github.com/WSH032/fastapi-proxy-lib/issues/40)
+ - [memory-object-streams](https://anyio.readthedocs.io/en/stable/streams.html#memory-object-streams)
+
+ Bug: Dead lock
+ The current implementation only supports a strict `one-receive-one-send` mode within a single loop.
+ If this pattern is violated, such as `multiple receives and one send`, `one receive and multiple sends`,
+ or `sending before receiving` within a single loop, it will result in a deadlock.
+
+ See Issue Tracker: [#42](https://github.com/WSH032/fastapi-proxy-lib/issues/42)
+
+ ```py
+ async def callback(ctx: CallbackPipeContextType[str]) -> None:
+ with ctx as (sender, receiver):
+ # multiple receives and one send, dead lock!
+ await receiver.receive()
+ await receiver.receive()
+ await sender.send("foo")
+
+ async def callback(ctx: CallbackPipeContextType[str]) -> None:
+ with ctx as (sender, receiver):
+ # one receive and multiple sends, dead lock!
+ async for message in receiver:
+ await sender.send("foo")
+ await sender.send("bar")
+
+ async def callback(ctx: CallbackPipeContextType[str]) -> None:
+ with ctx as (sender, receiver):
+ # sending before receiving, dead lock!
+ await sender.send("foo")
+ async for message in receiver:
+ await sender.send(message)
+ ```
Returns:
If the establish websocket connection unsuccessfully:
@@ -559,6 +795,9 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv
# NOTE: 对于反向代理服务器,我们不返回 "任何" "具体的内部" 错误信息给客户端,因为这可能涉及到服务器内部的信息泄露
+ # NOTE: Do not exit the `stack` before return `StreamingResponse` above.
+ # Because once the `stack` close, the `httpx_ws` connection will be closed,
+ # then the streaming response will encounter an error.
# NOTE: 请使用 with 语句来 "保证关闭" AsyncWebSocketSession
async with stack:
# TODO: websocket.accept 中还有一个headers参数,但是httpx_ws不支持,考虑发起PR
@@ -573,17 +812,34 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv
# headers=...
)
+ cts_pipe_ctx = (
+ await _enable_callback(
+ client_to_server_callback, "client_to_server_callback", stack
+ )
+ if client_to_server_callback is not None
+ else None
+ )
client_to_server_task = asyncio.create_task(
_wait_client_then_send_to_server(
client_ws=websocket,
server_ws=proxy_ws,
+ pipe_context=cts_pipe_ctx,
),
name="client_to_server_task",
)
+
+ stc_pipe_ctx = (
+ await _enable_callback(
+ server_to_client_callback, "server_to_client_callback", stack
+ )
+ if server_to_client_callback is not None
+ else None
+ )
server_to_client_task = asyncio.create_task(
_wait_server_then_send_to_client(
client_ws=websocket,
server_ws=proxy_ws,
+ pipe_context=stc_pipe_ctx,
),
name="server_to_client_task",
)
@@ -610,6 +866,8 @@ async def send_request_to_target( # pyright: ignore [reportIncompatibleMethodOv
# 因为第二种情况的存在,所以需要用 wait_for 强制让其退出
# 但考虑到第一种情况,先等它 1s ,看看能否正常退出
try:
+ # Here we just cancel the task that is not finished,
+ # but we don't handle websocket closing here.
_, pending = await asyncio.wait(
task_group,
return_when=asyncio.FIRST_COMPLETED,
@@ -666,7 +924,7 @@ class ReverseWebSocketProxy(BaseWebSocketProxy):
keepalive_ping_timeout_seconds: refer to [httpx_ws.aconnect_ws][]
Tip:
- [`httpx_ws.aconnect_ws`](https://frankie567.github.io/httpx-ws/reference/httpx_ws/#httpx_ws.aconnect_ws)
+ [httpx_ws.aconnect_ws][]
Bug: There is a issue for handshake response:
This WebSocket proxy can correctly forward request headers.
@@ -681,6 +939,8 @@ class ReverseWebSocketProxy(BaseWebSocketProxy):
# # Examples
+ ## Basic usage
+
```python
from contextlib import asynccontextmanager
from typing import AsyncIterator
@@ -704,10 +964,58 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]:
async def _(websocket: WebSocket, path: str = ""):
return await proxy.proxy(websocket=websocket, path=path)
- # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000`
+ # Then run shell: `uvicorn :app --host 127.0.0.1 --port 8000`
# visit the app: `ws://127.0.0.1:8000/`
# you can establish websocket connection with `ws://echo.websocket.events`
```
+
+ ## With callback
+
+ Tip:
+ See also: [`callback-implementation`][fastapi_proxy_lib.core.websocket.BaseWebSocketProxy.send_request_to_target--callback-implementation]
+
+ ```python
+ # NOTE: `CallbackPipeContextType` is a unstable public type hint,
+ # you shouldn't rely on it.
+ # You should create your own type hint instead.
+ from fastapi_proxy_lib.core.websocket import CallbackPipeContextType
+
+ # NOTE: Providing a specific type annotation for `CallbackPipeContextType`
+ # does not offer any runtime guarantees.
+ # This means that even if you annotate it as `[str]`,
+ # you may still receive `bytes` types,
+ # unless you are certain that the other end will only send `str`.
+ # The default generic is `[str | bytes]`.
+
+ async def client_to_server_callback(pipe_context: CallbackPipeContextType[str]) -> None:
+ with pipe_context as (sender, receiver):
+ async for message in receiver:
+ print(f"Received from client: {message}")
+ # here we modify the message with `CTS:` prefix
+ await sender.send(f"CTS:{message}")
+ # If `proxy` receives a disconnection request and websocket closed correctly,
+ # this message will be printed. Or not, if exception occurs.
+ print("client_to_server_callback end")
+
+ # we give a `bytes` type hint here, but it has no runtime guarantee,
+ # just make type-checker happy
+ async def server_to_client_callback(pipe_context: CallbackPipeContextType[bytes]) -> None:
+ with pipe_context as (sender, receiver):
+ async for message in receiver:
+ print(f"Received from server: {message}")
+ await sender.send(f"STC:{message}".encode()) # `bytes` here
+ print("server_to_client_callback end")
+
+ @app.websocket("/{path:path}")
+ async def _(websocket: WebSocket, path: str = ""):
+ return await proxy.proxy(
+ websocket=websocket,
+ path=path,
+ # register callback
+ client_to_server_callback=client_to_server_callback,
+ server_to_client_callback=server_to_client_callback,
+ )
+ ```
'''
client: httpx.AsyncClient
@@ -753,7 +1061,7 @@ def __init__(
keepalive_ping_timeout_seconds: refer to [httpx_ws.aconnect_ws][]
Tip:
- [`httpx_ws.aconnect_ws`](https://frankie567.github.io/httpx-ws/reference/httpx_ws/#httpx_ws.aconnect_ws)
+ [httpx_ws.aconnect_ws][]
"""
self.base_url = check_base_url(base_url)
super().__init__(
@@ -767,7 +1075,12 @@ def __init__(
@override
async def proxy( # pyright: ignore [reportIncompatibleMethodOverride]
- self, *, websocket: starlette_ws.WebSocket, path: Optional[str] = None
+ self,
+ *,
+ websocket: starlette_ws.WebSocket,
+ path: Optional[str] = None,
+ client_to_server_callback: Optional[_CallbackType[_WsMsgTypeVar_CTS]] = None,
+ server_to_client_callback: Optional[_CallbackType[_WsMsgTypeVar_STC]] = None,
) -> Union[Literal[False], StarletteResponse]:
"""Establish websocket connection for both client and target_url, then pass messages between them.
@@ -776,6 +1089,12 @@ async def proxy( # pyright: ignore [reportIncompatibleMethodOverride]
path: The path params of websocket request, which means the path params of base url.
If None, will get it from `websocket.path_params`.
**Usually, you don't need to pass this argument**.
+ client_to_server_callback: The callback function for processing data from client to server.
+ The implementation details are in the base class
+ [`BaseWebSocketProxy.send_request_to_target`][fastapi_proxy_lib.core.websocket.BaseWebSocketProxy.send_request_to_target--callback-implementation].
+ server_to_client_callback: The callback function for processing data from server to client.
+ The implementation details are in the base class
+ [`BaseWebSocketProxy.send_request_to_target`][fastapi_proxy_lib.core.websocket.BaseWebSocketProxy.send_request_to_target--callback-implementation].
Returns:
If the establish websocket connection unsuccessfully:
@@ -800,5 +1119,8 @@ async def proxy( # pyright: ignore [reportIncompatibleMethodOverride]
# self.send_request_to_target 内部会处理连接失败时,返回错误给客户端,所以这里不处理了
return await self.send_request_to_target(
- websocket=websocket, target_url=target_url
+ websocket=websocket,
+ target_url=target_url,
+ client_to_server_callback=client_to_server_callback,
+ server_to_client_callback=server_to_client_callback,
)
diff --git a/tests/test_docs_examples.py b/tests/test_docs_examples.py
index 12bea06..91cbf91 100644
--- a/tests/test_docs_examples.py
+++ b/tests/test_docs_examples.py
@@ -26,7 +26,7 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]:
async def _(request: Request, path: str = ""):
return await proxy.proxy(request=request, path=path)
- # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000`
+ # Then run shell: `uvicorn :app --host 127.0.0.1 --port 8000`
# visit the app: `http://127.0.0.1:8000/http://www.example.com`
# you will get the response from `http://www.example.com`
@@ -55,7 +55,7 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]: # (1)!
async def _(request: Request, path: str = ""):
return await proxy.proxy(request=request, path=path) # (3)!
- # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000`
+ # Then run shell: `uvicorn :app --host 127.0.0.1 --port 8000`
# visit the app: `http://127.0.0.1:8000/`
# you will get the response from `http://www.example.com/`
@@ -93,7 +93,7 @@ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]:
async def _(websocket: WebSocket, path: str = ""):
return await proxy.proxy(websocket=websocket, path=path)
- # Then run shell: `uvicorn :app --host http://127.0.0.1:8000 --port 8000`
+ # Then run shell: `uvicorn :app --host 127.0.0.1 --port 8000`
# visit the app: `ws://127.0.0.1:8000/`
# you can establish websocket connection with `ws://echo.websocket.events`
diff --git a/tests/test_ws.py b/tests/test_ws.py
index 2119719..3a8e464 100644
--- a/tests/test_ws.py
+++ b/tests/test_ws.py
@@ -1,18 +1,30 @@
# noqa: D100
-
import asyncio
-from contextlib import AsyncExitStack
+import gc
+from contextlib import AsyncExitStack, asynccontextmanager
from multiprocessing import Process, Queue
-from typing import Any, Dict, Literal, Optional
+from typing import Any, AsyncIterator, Dict, Literal, Optional, Protocol
+import anyio
import httpx
import httpx_ws
import pytest
import uvicorn
+from anyio import move_on_after
+from fastapi import FastAPI
+from fastapi_proxy_lib.core.websocket import (
+ CallbackPipeContextType,
+ ReverseWebSocketProxy,
+ # from this project, so it's ok
+ _CallbackType, # pyright: ignore[reportPrivateUsage]
+ _WsMsgTypeVar_CTS, # pyright: ignore[reportPrivateUsage]
+ _WsMsgTypeVar_STC, # pyright: ignore[reportPrivateUsage]
+)
from fastapi_proxy_lib.fastapi.app import reverse_ws_app as get_reverse_ws_app
from httpx_ws import aconnect_ws
from starlette import websockets as starlette_websockets_module
+from starlette.websockets import WebSocket as StarletteWebSocket
from typing_extensions import override
from .app.echo_ws_app import get_app as get_ws_test_app
@@ -37,6 +49,22 @@
NO_PROXIES: Dict[Any, Any] = {"all://": None}
+class WsAppFactory(Protocol):
+ """The factory for `Tool4TestFixture`."""
+
+ def __call__(self, client: httpx.AsyncClient, *, base_url: str) -> FastAPI:
+ """Return the ws proxy app for testing."""
+ ...
+
+
+class Tool4TestFixtureFactory(Protocol):
+ """The factory for `Tool4TestFixture`."""
+
+ async def __call__(self, ws_app_factory: WsAppFactory) -> Tool4TestFixture:
+ """See the implementation for details."""
+ ...
+
+
def _subprocess_run_echo_ws_uvicorn_server(queue: "Queue[str]", **kwargs: Any):
"""Run echo ws app in subprocess.
@@ -101,54 +129,109 @@ async def run():
asyncio.run(run())
+def callback_ws_app_factory_builder(
+ client_to_server_callback: Optional[_CallbackType[_WsMsgTypeVar_CTS]] = None,
+ server_to_client_callback: Optional[_CallbackType[_WsMsgTypeVar_STC]] = None,
+) -> WsAppFactory:
+ """Return a ws proxy app factory with callback."""
+
+ def callback_ws_app_factory(client: httpx.AsyncClient, *, base_url: str) -> FastAPI:
+ """Return a ws proxy app with callback."""
+ proxy = ReverseWebSocketProxy(
+ client=client,
+ base_url=base_url,
+ )
+
+ @asynccontextmanager
+ async def close_proxy_event(_: FastAPI) -> AsyncIterator[None]:
+ """Close proxy."""
+ yield
+ await proxy.aclose()
+
+ app = FastAPI(lifespan=close_proxy_event)
+
+ @app.websocket("/{path:path}")
+ async def _(websocket: StarletteWebSocket, path: str = ""):
+ return await proxy.proxy(
+ websocket=websocket,
+ path=path,
+ client_to_server_callback=client_to_server_callback,
+ server_to_client_callback=server_to_client_callback,
+ )
+
+ return app
+
+ return callback_ws_app_factory
+
+
class TestReverseWsProxy(AbstractTestProxy):
"""For testing reverse websocket proxy."""
- @override
@pytest.fixture(params=WS_BACKENDS_NEED_BE_TESTED)
- async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverride]
+ def tool_4_test_fixture_factory(
self,
uvicorn_server_fixture: UvicornServerFixture,
request: pytest.FixtureRequest,
- ) -> Tool4TestFixture:
+ ) -> Tool4TestFixtureFactory:
"""目标服务器请参考`tests.app.echo_ws_app.get_app`."""
- echo_ws_test_model = get_ws_test_app()
- echo_ws_app = echo_ws_test_model.app
- echo_ws_get_request = echo_ws_test_model.get_request
- target_ws_server = await uvicorn_server_fixture(
- uvicorn.Config(
- echo_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=request.param
- ),
- contx_exit_timeout=DEFAULT_CONTX_EXIT_TIMEOUT,
- )
+ async def _tool_4_test_fixture_factory(
+ ws_app_factory: WsAppFactory,
+ ) -> Tool4TestFixture:
+ """Create a tool for testing reverse websocket proxy.
+
+ Args:
+ ws_app_factory: A app factory for create reverse proxy websocket app,
+ """
+ echo_ws_test_model = get_ws_test_app()
+ echo_ws_app = echo_ws_test_model.app
+ echo_ws_get_request = echo_ws_test_model.get_request
+
+ target_ws_server = await uvicorn_server_fixture(
+ uvicorn.Config(
+ echo_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=request.param
+ ),
+ contx_exit_timeout=DEFAULT_CONTX_EXIT_TIMEOUT,
+ )
- target_server_base_url = str(target_ws_server.contx_socket_url)
+ target_server_base_url = str(target_ws_server.contx_socket_url)
- client_for_conn_to_target_server = httpx.AsyncClient(proxies=NO_PROXIES)
+ client_for_conn_to_target_server = httpx.AsyncClient(proxies=NO_PROXIES)
- reverse_ws_app = get_reverse_ws_app(
- client=client_for_conn_to_target_server, base_url=target_server_base_url
- )
+ reverse_ws_app = ws_app_factory(
+ client=client_for_conn_to_target_server, base_url=target_server_base_url
+ )
- proxy_ws_server = await uvicorn_server_fixture(
- uvicorn.Config(
- reverse_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=request.param
- ),
- contx_exit_timeout=DEFAULT_CONTX_EXIT_TIMEOUT,
- )
+ proxy_ws_server = await uvicorn_server_fixture(
+ uvicorn.Config(
+ reverse_ws_app,
+ port=DEFAULT_PORT,
+ host=DEFAULT_HOST,
+ ws=request.param,
+ ),
+ contx_exit_timeout=DEFAULT_CONTX_EXIT_TIMEOUT,
+ )
- proxy_server_base_url = str(proxy_ws_server.contx_socket_url)
+ proxy_server_base_url = str(proxy_ws_server.contx_socket_url)
- client_for_conn_to_proxy_server = httpx.AsyncClient(proxies=NO_PROXIES)
+ client_for_conn_to_proxy_server = httpx.AsyncClient(proxies=NO_PROXIES)
- return Tool4TestFixture(
- client_for_conn_to_target_server=client_for_conn_to_target_server,
- client_for_conn_to_proxy_server=client_for_conn_to_proxy_server,
- get_request=echo_ws_get_request,
- target_server_base_url=target_server_base_url,
- proxy_server_base_url=proxy_server_base_url,
- )
+ return Tool4TestFixture(
+ client_for_conn_to_target_server=client_for_conn_to_target_server,
+ client_for_conn_to_proxy_server=client_for_conn_to_proxy_server,
+ get_request=echo_ws_get_request,
+ target_server_base_url=target_server_base_url,
+ proxy_server_base_url=proxy_server_base_url,
+ )
+
+ return _tool_4_test_fixture_factory
+
+ @override
+ @pytest.fixture()
+ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverride]
+ self, tool_4_test_fixture_factory: Tool4TestFixtureFactory
+ ) -> Tool4TestFixture:
+ return await tool_4_test_fixture_factory(get_reverse_ws_app)
@pytest.mark.anyio()
async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None:
@@ -313,3 +396,160 @@ async def test_target_server_shutdown_abnormally(
# 只要第二个客户端不是在之前40s基础上又重复40s,就暂时没问题,
# 因为这模拟了多个客户端进行连接的情况。
assert (seconde_ws_recv_end - seconde_ws_recv_start) < 2
+
+ @pytest.mark.timeout(15) # prevent dead lock
+ @pytest.mark.anyio()
+ async def test_ws_proxy_with_callback(
+ self, tool_4_test_fixture_factory: Tool4TestFixtureFactory
+ ) -> None:
+ """Test reverse websocket proxy with callback."""
+ msg = "foo"
+ cts_prefix = "CTS:"
+ stc_prefix = "STC:"
+
+ cts_cb_receive = None
+ stc_cb_receive = None
+ cts_cb_done = anyio.Event()
+ stc_cb_done = anyio.Event()
+
+ async def client_to_server_callback(
+ pipe_context: CallbackPipeContextType[str],
+ ) -> None:
+ nonlocal cts_cb_receive, cts_cb_done
+ with pipe_context as (sender, receiver):
+ async for message in receiver:
+ cts_cb_receive = message
+ await sender.send(f"{cts_prefix}{message}")
+ cts_cb_done.set()
+
+ async def server_to_client_callback(
+ pipe_context: CallbackPipeContextType[str],
+ ) -> None:
+ with pipe_context as (sender, receiver):
+ nonlocal stc_cb_receive, stc_cb_done
+ async for message in receiver:
+ stc_cb_receive = message
+ await sender.send(f"{stc_prefix}{message}")
+ stc_cb_done.set()
+
+ tool_4_test_fixture = await tool_4_test_fixture_factory(
+ callback_ws_app_factory_builder(
+ client_to_server_callback=client_to_server_callback,
+ server_to_client_callback=server_to_client_callback,
+ )
+ )
+ proxy_server_base_url = tool_4_test_fixture.proxy_server_base_url
+ client_for_conn_to_proxy_server = (
+ tool_4_test_fixture.client_for_conn_to_proxy_server
+ )
+ get_request = tool_4_test_fixture.get_request
+
+ async with aconnect_ws(
+ proxy_server_base_url + "echo_text", client_for_conn_to_proxy_server
+ ) as ws:
+ await ws.send_text(msg)
+
+ assert (
+ await ws.receive_text() == f"{stc_prefix}{cts_prefix}{msg}"
+ ), "cts send wrong message"
+ assert cts_cb_receive == msg, "cts receive wrong message"
+ assert (
+ stc_cb_receive == f"{cts_prefix}{msg}"
+ ), "cts send wrong message or stc receive wrong message"
+
+ with move_on_after(2) as scope:
+ await cts_cb_done.wait()
+ await stc_cb_done.wait()
+ assert (
+ not scope.cancelled_caught
+ ), "after proxy finished, callback not done or done too late"
+
+ target_starlette_ws = get_request()
+ assert isinstance(target_starlette_ws, starlette_websockets_module.WebSocket)
+ # test target ws has disconnected
+ with pytest.raises(RuntimeError):
+ await target_starlette_ws.receive_text()
+
+ @pytest.mark.timeout(15) # prevent dead lock
+ @pytest.mark.anyio()
+ @pytest.mark.parametrize("forget_to_enter_pipe", [True, False])
+ async def test_ws_callback_error(
+ self,
+ tool_4_test_fixture_factory: Tool4TestFixtureFactory,
+ forget_to_enter_pipe: bool,
+ ) -> None:
+ """Test reverse websocket proxy with callback that will raise error."""
+ msg = "foo"
+ raise_exception_event = anyio.Event()
+
+ async def client_to_server_callback(
+ pipe_context: CallbackPipeContextType[str],
+ ) -> None:
+ """We do nothing, just raise exception."""
+ if not forget_to_enter_pipe:
+ # NOTE: we must ensure that exit the context manager to close the pipe,
+ # or will get dead lock.
+ with pipe_context as (_sender, _receiver):
+ await raise_exception_event.wait()
+ raise Exception()
+ else:
+ # Here, we forget to enter the pipe context, so the context never be exited.
+ # Because there are some mitigation measures, we will not get dead lock,
+ # but we still get warning.
+ await raise_exception_event.wait()
+ raise Exception()
+
+ class NullContext:
+ def __init__(self, *args: object, **kwargs: object) -> None:
+ pass
+
+ def __enter__(self, *args: object, **kwargs: object) -> None:
+ pass
+
+ def __exit__(self, *args: object, **kwargs: object) -> None:
+ pass
+
+ def null_collector() -> None:
+ pass
+
+ if not forget_to_enter_pipe:
+ # use `gc.collect` to make sure `__del__` method of pipe context be called,
+ # so that we can get the warning.
+ warning_capturer, gc_collector = pytest.warns, gc.collect
+ else:
+ warning_capturer, gc_collector = NullContext, null_collector
+
+ with warning_capturer(
+ RuntimeWarning,
+ match="You never exit the pipe context, it may cause a deadlock.",
+ ):
+ tool_4_test_fixture = await tool_4_test_fixture_factory(
+ callback_ws_app_factory_builder(
+ client_to_server_callback=client_to_server_callback,
+ )
+ )
+ proxy_server_base_url = tool_4_test_fixture.proxy_server_base_url
+ client_for_conn_to_proxy_server = (
+ tool_4_test_fixture.client_for_conn_to_proxy_server
+ )
+ get_request = tool_4_test_fixture.get_request
+
+ async with aconnect_ws(
+ proxy_server_base_url + "echo_text", client_for_conn_to_proxy_server
+ ) as ws:
+ await ws.send_text(msg)
+ raise_exception_event.set()
+
+ # client ws has disconnected
+ with pytest.raises(httpx_ws.WebSocketDisconnect):
+ await ws.receive_text()
+
+ target_starlette_ws = get_request()
+ assert isinstance(
+ target_starlette_ws, starlette_websockets_module.WebSocket
+ )
+ # test target ws has disconnected
+ with pytest.raises(RuntimeError):
+ await target_starlette_ws.receive_text()
+
+ gc_collector()