diff --git a/fastapi_mcp/server.py b/fastapi_mcp/server.py index fbff6e8..4b9c0c2 100644 --- a/fastapi_mcp/server.py +++ b/fastapi_mcp/server.py @@ -2,7 +2,7 @@ import httpx from typing import Dict, Optional, Any, List, Union -from fastapi import FastAPI, Request, APIRouter +from fastapi import FastAPI, Request, APIRouter, BackgroundTasks from fastapi.openapi.utils import get_openapi from mcp.server.lowlevel.server import Server import mcp.types as types @@ -183,8 +183,8 @@ async def handle_mcp_connection(request: Request): # Route for MCP messages @router.post(f"{mount_path}/messages/", include_in_schema=False, operation_id="mcp_messages") - async def handle_post_message(request: Request): - return await sse_transport.handle_fastapi_post_message(request) + async def handle_post_message(request: Request, background_tasks: BackgroundTasks): + return await sse_transport.handle_fastapi_post_message(request, background_tasks) # HACK: If we got a router and not a FastAPI instance, we need to re-include the router so that # FastAPI will pick up the new routes we added. The problem with this approach is that we assume diff --git a/fastapi_mcp/transport/sse.py b/fastapi_mcp/transport/sse.py index 9adb725..af2f7ab 100644 --- a/fastapi_mcp/transport/sse.py +++ b/fastapi_mcp/transport/sse.py @@ -14,7 +14,7 @@ class FastApiSseTransport(SseServerTransport): - async def handle_fastapi_post_message(self, request: Request) -> Response: + async def handle_fastapi_post_message(self, request: Request, background_tasks: BackgroundTasks) -> Response: """ A reimplementation of the handle_post_message method of SseServerTransport that integrates better with FastAPI. @@ -60,8 +60,7 @@ async def handle_fastapi_post_message(self, request: Request) -> Response: logger.debug(f"Validated client message: {message}") except ValidationError as err: logger.error(f"Failed to parse message: {err}") - # Create background task to send error - background_tasks = BackgroundTasks() + # Create background task to send error, to avoid ASGI race conditions background_tasks.add_task(self._send_message_safely, writer, err) response = JSONResponse(content={"error": "Could not parse message"}, status_code=400) response.background = background_tasks @@ -70,8 +69,7 @@ async def handle_fastapi_post_message(self, request: Request) -> Response: logger.error(f"Error processing request body: {e}") raise HTTPException(status_code=400, detail="Invalid request body") - # Create background task to send message - background_tasks = BackgroundTasks() + # Create background task to send message, to avoid ASGI race conditions background_tasks.add_task(self._send_message_safely, writer, message) logger.debug("Accepting message, will send in background")