diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py index 2f1e4086f..794353d60 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -10,7 +10,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.responses import JSONResponse, RedirectResponse, Response +from starlette.responses import JSONResponse, RedirectResponse, Response, HTMLResponse from mcp.server.auth.middleware.auth_context import get_access_token from mcp.server.auth.provider import ( @@ -25,6 +25,7 @@ from mcp.server.fastmcp.server import FastMCP from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.auth import OAuthClientInformationFull, OAuthToken +from urllib.parse import urlencode logger = logging.getLogger(__name__) @@ -98,16 +99,23 @@ async def authorize( "client_id": client.client_id, } - # Build GitHub authorization URL - auth_url = ( - f"{self.settings.github_auth_url}" - f"?client_id={self.settings.github_client_id}" - f"&redirect_uri={self.settings.github_callback_path}" - f"&scope={self.settings.github_scope}" - f"&state={state}" - ) + # Return our custom consent endpoint, which will then redirect to Github + + # Extract scopes - use default MCP scope if none provided + scopes = params.scopes or [self.settings.mcp_scope] + scopes_string = " ".join(scopes) if isinstance(scopes, list) else str(scopes) + + consent_params = { + "client_id": client.client_id, + "redirect_uri": str(params.redirect_uri), + "state": state, + "scopes": scopes_string, + "code_challenge": params.code_challenge or "", + "response_type": "code" + } - return auth_url + consent_url = f"{self.settings.server_url}consent?{urlencode(consent_params)}" + return consent_url async def handle_github_callback(self, code: str, state: str) -> str: """Handle GitHub OAuth callback.""" @@ -255,6 +263,226 @@ async def revoke_token( del self.tokens[token] +class ConsentHandler: + + def __init__(self, provider: SimpleGitHubOAuthProvider, settings: ServerSettings, path: str): + self.provider: SimpleGitHubOAuthProvider = provider + self.settings: ServerSettings = settings + self.client_consent: dict[str, bool] = {} + self.path = path + + async def handle(self, request: Request) -> Response: + # This handles both showing the consent form (GET) and processing consent (POST) + if request.method == "GET": + # Show consent form + return await self._show_consent_form(request) + elif request.method == "POST": + # Process consent + return await self._process_consent(request) + else: + return HTMLResponse(status_code=405, content="Method not allowed") + + async def _show_consent_form(self, request: Request) -> HTMLResponse: + client_id = request.query_params.get("client_id", "") + redirect_uri = request.query_params.get("redirect_uri", "") + # TODO: address csrf + state = request.query_params.get("state", "") + scopes = request.query_params.get("scopes", "") + code_challenge = request.query_params.get("code_challenge", "") + response_type = request.query_params.get("response_type", "") + + # Get client info to display client_name + client_name = client_id # Default to client_id if we can't get the client + if client_id: + client = await self.provider.get_client(client_id) + if client and hasattr(client, 'client_name'): + client_name = client.client_name + + target_url = self.path + + # TODO: allow skipping consent if we've already approved this client ID + + # Create a simple consent form + html_content = f""" + + + + Authorization Required + + + + + + + + +""" + return HTMLResponse(content=html_content) + + async def _process_consent(self, request: Request) -> RedirectResponse | HTMLResponse: + form_data = await request.form() + action = form_data.get("action") + state = form_data.get("state") + + if action == "approve": + # Grant consent and continue with authorization + client_id = form_data.get("client_id") + if client_id: + client = await self.provider.get_client(client_id) + if client: + self.client_consent[client.client_id] = True + + auth_url = ( + f"{self.settings.github_auth_url}" + f"?client_id={self.settings.github_client_id}" + f"&redirect_uri={self.settings.github_callback_path}" + f"&scope={self.settings.github_scope}" + f"&state={state}" + ) + + return RedirectResponse( + # TODO: get this passed in + url=auth_url, + status_code=302, + headers={"Cache-Control": "no-store"}, + ) + else: + # User denied consent + redirect_uri = form_data.get("redirect_uri") + state = form_data.get("state") + + error_params = { + "error": "access_denied", + "error_description": "User denied the authorization request" + } + if state: + error_params["state"] = state + + if redirect_uri: + return RedirectResponse( + url=f"{redirect_uri}?{urlencode(error_params)}", + status_code=302, + headers={"Cache-Control": "no-store"}, + ) + else: + return HTMLResponse( + status_code=400, + content=f"Access denied: {error_params['error_description']}" + ) + + def _format_scopes(self, scopes: str) -> str: + if not scopes: + return "

No specific permissions requested

" + + scope_list = scopes.split() + if not scope_list: + return "

No specific permissions requested

" + + scope_html = "" + for scope in scope_list: + scope_html += f'
{scope}
' + + return scope_html + + + + def create_simple_mcp_server(settings: ServerSettings) -> FastMCP: """Create a simple FastMCP server with GitHub OAuth.""" oauth_provider = SimpleGitHubOAuthProvider(settings) @@ -279,6 +507,13 @@ def create_simple_mcp_server(settings: ServerSettings) -> FastMCP: auth=auth_settings, ) + consent_path = "/consent" + consent_handler = ConsentHandler(provider=oauth_provider, settings=settings, path=consent_path) + + @app.custom_route(consent_path, methods=["GET", "POST"]) + async def example_consent_handler(request: Request) -> Response: + return await consent_handler.handle(request) + @app.custom_route("/github/callback", methods=["GET"]) async def github_callback_handler(request: Request) -> Response: """Handle GitHub OAuth callback."""