From f3b948ed3a16f355f01c39422e9d708c92899a98 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 15 May 2025 13:01:30 -0700 Subject: [PATCH] feat: support audio content Adds support for a distinct AudioContent type as defined in the specification. This happens to share a structure with ImageContent, but should probably be distinguished for parity with the specification itself. --- .../server.py | 7 +++++- .../mcp_simple_streamablehttp/server.py | 7 +++++- .../simple-tool/mcp_simple_tool/server.py | 11 +++++++-- src/mcp/server/fastmcp/prompts/base.py | 4 ++-- src/mcp/server/fastmcp/server.py | 7 +++--- src/mcp/server/lowlevel/server.py | 5 +++- src/mcp/types.py | 23 +++++++++++++++---- tests/issues/test_88_random_error.py | 3 ++- tests/server/fastmcp/test_server.py | 12 ++++++---- 9 files changed, 60 insertions(+), 19 deletions(-) diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index f718df801..569f0ccfc 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -43,7 +43,12 @@ def main( @app.call_tool() async def call_tool( name: str, arguments: dict - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + ) -> list[ + types.TextContent + | types.ImageContent + | types.AudioContent + | types.EmbeddedResource + ]: ctx = app.request_context interval = arguments.get("interval", 1.0) count = arguments.get("count", 5) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index 1a76097b5..605fe916b 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -47,7 +47,12 @@ def main( @app.call_tool() async def call_tool( name: str, arguments: dict - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + ) -> list[ + types.TextContent + | types.ImageContent + | types.AudioContent + | types.EmbeddedResource + ]: ctx = app.request_context interval = arguments.get("interval", 1.0) count = arguments.get("count", 5) diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 5f4e28bb7..62c954743 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -7,7 +7,9 @@ async def fetch_website( url: str, -) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: +) -> list[ + types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource +]: headers = { "User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)" } @@ -31,7 +33,12 @@ def main(port: int, transport: str) -> int: @app.call_tool() async def fetch_tool( name: str, arguments: dict - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + ) -> list[ + types.TextContent + | types.ImageContent + | types.AudioContent + | types.EmbeddedResource + ]: if name != "fetch": raise ValueError(f"Unknown tool: {name}") if "url" not in arguments: diff --git a/src/mcp/server/fastmcp/prompts/base.py b/src/mcp/server/fastmcp/prompts/base.py index aa3d1eac9..33bf68025 100644 --- a/src/mcp/server/fastmcp/prompts/base.py +++ b/src/mcp/server/fastmcp/prompts/base.py @@ -7,9 +7,9 @@ import pydantic_core from pydantic import BaseModel, Field, TypeAdapter, validate_call -from mcp.types import EmbeddedResource, ImageContent, TextContent +from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent -CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource +CONTENT_TYPES = TextContent | ImageContent | AudioContent | EmbeddedResource class Message(BaseModel): diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 21c31b0b3..505ee55ec 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -52,6 +52,7 @@ from mcp.shared.context import LifespanContextT, RequestContext from mcp.types import ( AnyFunction, + AudioContent, EmbeddedResource, GetPromptResult, ImageContent, @@ -271,7 +272,7 @@ def get_context(self) -> Context[ServerSession, object]: async def call_tool( self, name: str, arguments: dict[str, Any] - ) -> Sequence[TextContent | ImageContent | EmbeddedResource]: + ) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]: """Call a tool by name with arguments.""" context = self.get_context() result = await self._tool_manager.call_tool(name, arguments, context=context) @@ -871,12 +872,12 @@ async def get_prompt( def _convert_to_content( result: Any, -) -> Sequence[TextContent | ImageContent | EmbeddedResource]: +) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]: """Convert a result to a sequence of content objects.""" if result is None: return [] - if isinstance(result, TextContent | ImageContent | EmbeddedResource): + if isinstance(result, TextContent | ImageContent | AudioContent | EmbeddedResource): return [result] if isinstance(result, Image): diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 876aef817..b4ff85429 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -400,7 +400,10 @@ def decorator( ..., Awaitable[ Iterable[ - types.TextContent | types.ImageContent | types.EmbeddedResource + types.TextContent + | types.ImageContent + | types.AudioContent + | types.EmbeddedResource ] ], ], diff --git a/src/mcp/types.py b/src/mcp/types.py index d864b19da..35ae1d327 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -651,11 +651,26 @@ class ImageContent(BaseModel): model_config = ConfigDict(extra="allow") +class AudioContent(BaseModel): + """Audio content for a message.""" + + type: Literal["audio"] + data: str + """The base64-encoded audio data.""" + mimeType: str + """ + The MIME type of the audio. Different providers may support different + audio types. + """ + annotations: Annotations | None = None + model_config = ConfigDict(extra="allow") + + class SamplingMessage(BaseModel): """Describes a message issued to or received from an LLM API.""" role: Role - content: TextContent | ImageContent + content: TextContent | ImageContent | AudioContent model_config = ConfigDict(extra="allow") @@ -677,7 +692,7 @@ class PromptMessage(BaseModel): """Describes a message returned as part of a prompt.""" role: Role - content: TextContent | ImageContent | EmbeddedResource + content: TextContent | ImageContent | AudioContent | EmbeddedResource model_config = ConfigDict(extra="allow") @@ -796,7 +811,7 @@ class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]): class CallToolResult(Result): """The server's response to a tool call.""" - content: list[TextContent | ImageContent | EmbeddedResource] + content: list[TextContent | ImageContent | AudioContent | EmbeddedResource] isError: bool = False @@ -960,7 +975,7 @@ class CreateMessageResult(Result): """The client's response to a sampling/create_message request from the server.""" role: Role - content: TextContent | ImageContent + content: TextContent | ImageContent | AudioContent model: str """The name of the model that generated the message.""" stopReason: StopReason | None = None diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 88e41d66d..9b21e4ba1 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -12,6 +12,7 @@ from mcp.server.lowlevel import Server from mcp.shared.exceptions import McpError from mcp.types import ( + AudioContent, EmbeddedResource, ImageContent, TextContent, @@ -37,7 +38,7 @@ async def test_notification_validation_error(tmp_path: Path): @server.call_tool() async def slow_tool( name: str, arg - ) -> Sequence[TextContent | ImageContent | EmbeddedResource]: + ) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]: nonlocal request_count request_count += 1 diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index b817761ea..71cad7e68 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -16,6 +16,7 @@ create_connected_server_and_client_session as client_session, ) from mcp.types import ( + AudioContent, BlobResourceContents, ImageContent, TextContent, @@ -207,10 +208,11 @@ def image_tool_fn(path: str) -> Image: return Image(path) -def mixed_content_tool_fn() -> list[TextContent | ImageContent]: +def mixed_content_tool_fn() -> list[TextContent | ImageContent | AudioContent]: return [ TextContent(type="text", text="Hello"), ImageContent(type="image", data="abc", mimeType="image/png"), + AudioContent(type="audio", data="def", mimeType="audio/wav"), ] @@ -312,14 +314,16 @@ async def test_tool_mixed_content(self): mcp.add_tool(mixed_content_tool_fn) async with client_session(mcp._mcp_server) as client: result = await client.call_tool("mixed_content_tool_fn", {}) - assert len(result.content) == 2 - content1 = result.content[0] - content2 = result.content[1] + assert len(result.content) == 3 + content1, content2, content3 = result.content assert isinstance(content1, TextContent) assert content1.text == "Hello" assert isinstance(content2, ImageContent) assert content2.mimeType == "image/png" assert content2.data == "abc" + assert isinstance(content3, AudioContent) + assert content3.mimeType == "audio/wav" + assert content3.data == "def" @pytest.mark.anyio async def test_tool_mixed_list_with_image(self, tmp_path: Path):