Skip to content

add Tool.outputSchema and CallToolResult.structuredContent #685

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
addce22
work in progress attempting to implement RFC 356
davemssavage May 10, 2025
c3d4d4f
outputSchema basics implemented, more testing required also potential…
davemssavage May 11, 2025
994ebad
allow output schema to be None distinct from empty dict which makes n…
davemssavage May 11, 2025
7131de6
exclude unset to drop default None values
davemssavage May 11, 2025
ad2ec44
remove exclude_unset doesn't seem to work
davemssavage May 11, 2025
7f36822
formatting
davemssavage May 11, 2025
06681d6
add test for function with no return annotation
davemssavage May 11, 2025
9db284b
fix ruff check error
davemssavage May 11, 2025
3b28fba
fix test doc string
davemssavage May 11, 2025
6244899
tidy up if else block
davemssavage May 12, 2025
e7c6727
refactor to support approach taken in RFC 371
davemssavage May 16, 2025
ecc7146
Merge branch 'main' into outputSchema
davemssavage May 16, 2025
6d2882c
tidy up
davemssavage May 17, 2025
982f6b0
Merge branch 'outputSchema' of https://github.com/davemssavage/python…
davemssavage May 17, 2025
43ebe80
add schema checking on client side via jsonschema
davemssavage May 17, 2025
ad83eea
Tidy up to follow conventions for ToolOutputValidation as used in res…
davemssavage May 17, 2025
3261cbc
more tidy up
davemssavage May 17, 2025
b16e716
enable tool to explicitly state output schema to override automatic o…
davemssavage May 17, 2025
4d327cd
add logging
davemssavage May 17, 2025
9109577
add a no op validator, add logging and ensure cache is refreshed prop…
davemssavage May 17, 2025
1746ea1
use schema to decide what to do about content conversion
davemssavage May 18, 2025
d3986f2
ruff check/formatting fixes
davemssavage May 18, 2025
c2168f2
pyright fixes
davemssavage May 18, 2025
76d1a7f
Merge branch 'main' into outputSchema
davemssavage May 24, 2025
b738f1b
fixed failing test
davemssavage May 24, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"sse-starlette>=1.6.1",
"pydantic-settings>=2.5.2",
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
"jsonschema==4.23.0",
]

[project.optional-dependencies]
Expand Down
137 changes: 127 additions & 10 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import logging
from collections.abc import Awaitable, Callable
from datetime import timedelta
from typing import Any, Protocol
from typing import Any, Protocol, TypeAlias

import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from jsonschema import ValidationError, validate
from pydantic import AnyUrl, TypeAdapter

import mcp.types as types
Expand All @@ -11,6 +14,8 @@
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS

logger = logging.getLogger(__name__)

DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")


Expand Down Expand Up @@ -44,6 +49,12 @@ async def __call__(
) -> None: ...


class ToolOutputValidationFnT(Protocol):
async def __call__(
self, request: types.CallToolRequest, result: types.CallToolResult
) -> bool: ...


async def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
Expand Down Expand Up @@ -77,6 +88,25 @@ async def _default_logging_callback(
pass


ToolOutputValidatorProvider: TypeAlias = Callable[
...,
Awaitable[ToolOutputValidationFnT],
]


# this bag of spanners is required in order to
# enable the client session to be parsed to the validator
async def _python_circularity_hell(arg: Any) -> ToolOutputValidationFnT:
# in any sane version of the universe this should never happen
# of course in any sane programming language class circularity
# dependencies shouldn't be this hard to manage
raise RuntimeError(
"Help I'm stuck in python circularity hell, please send biscuits"
)


_default_tool_output_validator: ToolOutputValidatorProvider = _python_circularity_hell

ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
types.ClientResult | types.ErrorData
)
Expand All @@ -101,6 +131,7 @@ def __init__(
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
tool_output_validator_provider: ToolOutputValidatorProvider | None = None,
) -> None:
super().__init__(
read_stream,
Expand All @@ -114,6 +145,9 @@ def __init__(
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
self._tool_output_validator_provider = (
tool_output_validator_provider or _default_tool_output_validator
)

async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability()
Expand Down Expand Up @@ -154,6 +188,8 @@ async def initialize(self) -> types.InitializeResult:
)
)

self._tool_output_validator = await self._tool_output_validator_provider(self)

return result

async def send_ping(self) -> types.EmptyResult:
Expand Down Expand Up @@ -275,24 +311,33 @@ async def call_tool(
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
progress_callback: ProgressFnT | None = None,
validate_result: bool = True,
) -> types.CallToolResult:
"""Send a tools/call request with optional progress callback support."""

return await self.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name=name,
arguments=arguments,
),
)
request = types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name=name,
arguments=arguments,
),
)

result = await self.send_request(
types.ClientRequest(request),
types.CallToolResult,
request_read_timeout_seconds=read_timeout_seconds,
progress_callback=progress_callback,
)

if validate_result:
valid = await self._tool_output_validator(request, result)

if not valid:
raise RuntimeError("Server responded with invalid result: " f"{result}")
# not validating or is valid
return result

async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult:
"""Send a prompts/list request."""
return await self.send_request(
Expand Down Expand Up @@ -412,3 +457,75 @@ async def _received_notification(
await self._logging_callback(params)
case _:
pass


class NoOpToolOutputValidator(ToolOutputValidationFnT):
async def __call__(
self, request: types.CallToolRequest, result: types.CallToolResult
) -> bool:
return True


class SimpleCachingToolOutputValidator(ToolOutputValidationFnT):
_schema_cache: dict[str, dict[str, Any] | bool]

def __init__(self, session: ClientSession):
self._session = session
self._schema_cache = {}
self._refresh_cache = True

async def __call__(
self, request: types.CallToolRequest, result: types.CallToolResult
) -> bool:
if result.isError:
# allow errors to be propagated
return True
else:
if self._refresh_cache:
await self._refresh_schema_cache()

schema = self._schema_cache.get(request.params.name)

if schema is None:
raise RuntimeError(f"Unknown tool {request.params.name}")
elif schema is False:
# no schema
logging.debug("No schema found checking structuredContent is empty")
return result.structuredContent is None
else:
try:
# TODO opportunity to build jsonschema.protocol.Validator
# and reuse rather than build every time
validate(result.structuredContent, schema)
return True
except ValidationError as e:
logging.exception(e)
return False

async def _refresh_schema_cache(self):
cursor = None
first = True
self._schema_cache = {}
while first or cursor is not None:
first = False
tools_result = await self._session.list_tools(cursor)
for tool in tools_result.tools:
# store a flag to be able to later distinguish between
# no schema for tool and unknown tool which can't be verified
schema_or_flag = (
False if tool.outputSchema is None else tool.outputSchema
)
self._schema_cache[tool.name] = schema_or_flag
cursor = tools_result.nextCursor
continue

self._refresh_cache = False


async def _escape_from_circular_python_hell(
session: ClientSession,
) -> ToolOutputValidationFnT:
return SimpleCachingToolOutputValidator(session)


_default_tool_output_validator = _escape_from_circular_python_hell
55 changes: 36 additions & 19 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def run(
def _setup_handlers(self) -> None:
"""Set up core MCP protocol handlers."""
self._mcp_server.list_tools()(self.list_tools)
self._mcp_server.call_tool()(self.call_tool)
self._mcp_server.call_tool()(self.call_tool, self._tool_manager.get_schema)
self._mcp_server.list_resources()(self.list_resources)
self._mcp_server.read_resource()(self.read_resource)
self._mcp_server.list_prompts()(self.list_prompts)
Expand All @@ -255,6 +255,7 @@ async def list_tools(self) -> list[MCPTool]:
name=info.name,
description=info.description,
inputSchema=info.parameters,
outputSchema=info.output,
annotations=info.annotations,
)
for info in tools
Expand All @@ -277,7 +278,8 @@ async def call_tool(
"""Call a tool by name with arguments."""
context = self.get_context()
result = await self._tool_manager.call_tool(name, arguments, context=context)
converted_result = _convert_to_content(result)
schema = self._tool_manager.get_schema(name)
converted_result = _convert_to_content(result, schema)
return converted_result

async def list_resources(self) -> list[MCPResource]:
Expand Down Expand Up @@ -325,6 +327,7 @@ def add_tool(
name: str | None = None,
description: str | None = None,
annotations: ToolAnnotations | None = None,
output_schema: dict[str, Any] | None = None,
) -> None:
"""Add a tool to the server.

Expand All @@ -336,9 +339,15 @@ def add_tool(
name: Optional name for the tool (defaults to function name)
description: Optional description of what the tool does
annotations: Optional ToolAnnotations providing additional tool information
output_schema: Optional json schema that the tool should output. If
not specified the schema will be inferred automatically
"""
self._tool_manager.add_tool(
fn, name=name, description=description, annotations=annotations
fn,
name=name,
description=description,
annotations=annotations,
output_schema=output_schema,
)

def tool(
Expand Down Expand Up @@ -872,25 +881,33 @@ async def get_prompt(


def _convert_to_content(
result: Any,
result: Any, schema: dict[str, Any] | None
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
"""Convert a result to a sequence of content objects."""
if result is None:
return []

if isinstance(result, TextContent | ImageContent | EmbeddedResource):
return [result]

if isinstance(result, Image):
return [result.to_image_content()]

if isinstance(result, list | tuple):
return list(chain.from_iterable(_convert_to_content(item) for item in result)) # type: ignore[reportUnknownVariableType]
if schema is None:
"""Convert a result to a sequence of content objects."""
if result is None:
return []

if isinstance(result, TextContent | ImageContent | EmbeddedResource):
return [result]

if isinstance(result, Image):
return [result.to_image_content()]

if isinstance(result, list | tuple):
return list(
chain.from_iterable(
_convert_to_content(item, schema)
for item in result # type: ignore[reportUnknownVariableType]
)
)

if not isinstance(result, str):
result = pydantic_core.to_json(result, fallback=str, indent=2).decode()
if not isinstance(result, str):
result = pydantic_core.to_json(result, fallback=str, indent=2).decode()

return [TextContent(type="text", text=result)]
return [TextContent(type="text", text=result)]
else:
return result


class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
Expand Down
8 changes: 8 additions & 0 deletions src/mcp/server/fastmcp/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ class Tool(BaseModel):
name: str = Field(description="Name of the tool")
description: str = Field(description="Description of what the tool does")
parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
output: dict[str, Any] | None = Field(
description="JSON schema for tool output",
default=None,
)
fn_metadata: FuncMetadata = Field(
description="Metadata about the function including a pydantic model for tool"
" arguments"
Expand All @@ -44,6 +48,7 @@ def from_function(
description: str | None = None,
context_kwarg: str | None = None,
annotations: ToolAnnotations | None = None,
output_schema: dict[str, Any] | None = None,
) -> Tool:
"""Create a Tool from a function."""
from mcp.server.fastmcp.server import Context
Expand All @@ -68,14 +73,17 @@ def from_function(
func_arg_metadata = func_metadata(
fn,
skip_names=[context_kwarg] if context_kwarg is not None else [],
output_schema=output_schema,
)
parameters = func_arg_metadata.arg_model.model_json_schema()
output = func_arg_metadata.output_schema

return cls(
fn=fn,
name=func_name,
description=func_doc,
parameters=parameters,
output=output,
fn_metadata=func_arg_metadata,
is_async=is_async,
context_kwarg=context_kwarg,
Expand Down
13 changes: 12 additions & 1 deletion src/mcp/server/fastmcp/tools/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,15 @@ def add_tool(
name: str | None = None,
description: str | None = None,
annotations: ToolAnnotations | None = None,
output_schema: dict[str, Any] | None = None,
) -> Tool:
"""Add a tool to the server."""
tool = Tool.from_function(
fn, name=name, description=description, annotations=annotations
fn,
name=name,
description=description,
annotations=annotations,
output_schema=output_schema,
)
existing = self._tools.get(tool.name)
if existing:
Expand All @@ -73,3 +78,9 @@ async def call_tool(
raise ToolError(f"Unknown tool: {name}")

return await tool.run(arguments, context=context)

def get_schema(self, name: str) -> dict[str, Any] | None:
tool = self.get_tool(name)
if not tool:
raise ToolError(f"Unknown tool: {name}")
return tool.output
Loading
Loading