7
7
from dataclasses import dataclass
8
8
from typing import Any , Callable
9
9
10
+ from exceptiongroup import BaseExceptionGroup
10
11
from starlette .applications import Starlette
11
12
from starlette .middleware .cors import CORSMiddleware
12
13
from starlette .requests import Request
@@ -137,8 +138,6 @@ async def serve_index(request: Request) -> HTMLResponse:
137
138
def _setup_single_view_dispatcher_route (
138
139
options : Options , app : Starlette , component : RootComponentConstructor
139
140
) -> None :
140
- @app .websocket_route (str (STREAM_PATH ))
141
- @app .websocket_route (f"{ STREAM_PATH } /{{path:path}}" )
142
141
async def model_stream (socket : WebSocket ) -> None :
143
142
await socket .accept ()
144
143
send , recv = _make_send_recv_callbacks (socket )
@@ -162,8 +161,16 @@ async def model_stream(socket: WebSocket) -> None:
162
161
send ,
163
162
recv ,
164
163
)
165
- except WebSocketDisconnect as error :
166
- logger .info (f"WebSocket disconnect: { error .code } " )
164
+ except BaseExceptionGroup as egroup :
165
+ for e in egroup .exceptions :
166
+ if isinstance (e , WebSocketDisconnect ):
167
+ logger .info (f"WebSocket disconnect: { e .code } " )
168
+ break
169
+ else :
170
+ raise
171
+
172
+ app .add_websocket_route (str (STREAM_PATH ), model_stream )
173
+ app .add_websocket_route (f"{ STREAM_PATH } /{{path:path}}" , model_stream )
167
174
168
175
169
176
def _make_send_recv_callbacks (
0 commit comments