diff --git a/src/firebase_functions/https_fn.py b/src/firebase_functions/https_fn.py index 84c5659..45cdc5e 100644 --- a/src/firebase_functions/https_fn.py +++ b/src/firebase_functions/https_fn.py @@ -280,7 +280,7 @@ class AuthData: The interface for Auth tokens verified in Callable functions """ - uid: str + uid: str | None """ User ID of the ID token. """ @@ -346,8 +346,10 @@ class CallableRequest(_typing.Generic[_core.T]): _C2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any] -def _on_call_handler(func: _C2, request: Request, - enforce_app_check: bool) -> Response: +def _on_call_handler(func: _C2, + request: Request, + enforce_app_check: bool, + verify_token: bool = True) -> Response: try: if not _util.valid_on_call_request(request): _logging.error("Invalid request, unable to process.") @@ -357,7 +359,8 @@ def _on_call_handler(func: _C2, request: Request, data=_json.loads(request.data)["data"], ) - token_status = _util.on_call_check_tokens(request) + token_status = _util.on_call_check_tokens(request, + verify_token=verify_token) if token_status.auth == _util.OnCallTokenState.INVALID: raise HttpsError(FunctionsErrorCode.UNAUTHENTICATED, @@ -377,8 +380,10 @@ def _on_call_handler(func: _C2, request: Request, if token_status.auth_token is not None: context = _dataclasses.replace( context, - auth=AuthData(token_status.auth_token["uid"], - token_status.auth_token), + auth=AuthData( + token_status.auth_token["uid"] + if "uid" in token_status.auth_token else None, + token_status.auth_token), ) instance_id = request.headers.get("Firebase-Instance-ID-Token") @@ -399,7 +404,7 @@ def _on_call_handler(func: _C2, request: Request, # pylint: disable=broad-except except Exception as err: if not isinstance(err, HttpsError): - _logging.error("Unhandled error", err) + _logging.error("Unhandled error: %s", err) err = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL") status = err._http_error_code.status return _make_response(_jsonify(error=err._as_dict()), status) diff --git a/src/firebase_functions/private/util.py b/src/firebase_functions/private/util.py index 0997f8d..4a939af 100644 --- a/src/firebase_functions/private/util.py +++ b/src/firebase_functions/private/util.py @@ -15,8 +15,10 @@ Module for internal utilities. """ +import base64 import os as _os import json as _json +import re as _re import typing as _typing import dataclasses as _dataclasses import datetime as _dt @@ -29,6 +31,9 @@ P = _typing.ParamSpec("P") R = _typing.TypeVar("R") +JWT_REGEX = _re.compile( + r"^[a-zA-Z0-9\-_=]+?\.[a-zA-Z0-9\-_=]+?\.([a-zA-Z0-9\-_=]+)?$") + class Sentinel: """Internal class for RESET_VALUE.""" @@ -204,9 +209,13 @@ def as_dict(self) -> dict: def _on_call_check_auth_token( - request: _Request + request: _Request, + verify_token: bool = True, ) -> None | _typing.Literal[OnCallTokenState.INVALID] | dict[str, _typing.Any]: - """Validates the auth token in a callable request.""" + """ + Validates the auth token in a callable request. + If verify_token is False, the token will be decoded without verification. + """ authorization = request.headers.get("Authorization") if authorization is None: return None @@ -215,13 +224,15 @@ def _on_call_check_auth_token( return OnCallTokenState.INVALID try: id_token = authorization.replace("Bearer ", "") - auth_token = _auth.verify_id_token(id_token) + if verify_token: + auth_token = _auth.verify_id_token(id_token) + else: + auth_token = _unsafe_decode_id_token(id_token) return auth_token # pylint: disable=broad-except except Exception as err: _logging.error(f"Error validating token: {err}") return OnCallTokenState.INVALID - return OnCallTokenState.INVALID def _on_call_check_app_token( @@ -240,23 +251,44 @@ def _on_call_check_app_token( return OnCallTokenState.INVALID -def on_call_check_tokens(request: _Request,) -> _OnCallTokenVerification: +def _unsafe_decode_id_token(token: str): + # Check if the token matches the JWT pattern + if not JWT_REGEX.match(token): + return {} + + # Split the token by '.' and decode each component from base64 + components = [base64.urlsafe_b64decode(f"{s}==") for s in token.split(".")] + + # Attempt to parse the payload (second component) as JSON + payload = components[1].decode("utf-8") + try: + payload = _json.loads(payload) + except _json.JSONDecodeError: + # If there's an error during parsing, ignore it and return the payload as is + pass + + return payload + + +def on_call_check_tokens(request: _Request, + verify_token: bool = True) -> _OnCallTokenVerification: """Check tokens""" verifications = _OnCallTokenVerification() - auth_token = _on_call_check_auth_token(request) + auth_token = _on_call_check_auth_token(request, verify_token=verify_token) if auth_token is None: verifications.auth = OnCallTokenState.MISSING elif isinstance(auth_token, dict): verifications.auth = OnCallTokenState.VALID verifications.auth_token = auth_token - app_token = _on_call_check_app_token(request) - if app_token is None: - verifications.app = OnCallTokenState.MISSING - elif isinstance(app_token, dict): - verifications.app = OnCallTokenState.VALID - verifications.app_token = app_token + if verify_token: + app_token = _on_call_check_app_token(request) + if app_token is None: + verifications.app = OnCallTokenState.MISSING + elif isinstance(app_token, dict): + verifications.app = OnCallTokenState.VALID + verifications.app_token = app_token log_payload = { **verifications.as_dict(), @@ -266,7 +298,7 @@ def on_call_check_tokens(request: _Request,) -> _OnCallTokenVerification: } errs = [] - if verifications.app == OnCallTokenState.INVALID: + if verify_token and verifications.app == OnCallTokenState.INVALID: errs.append(("AppCheck token was rejected.", log_payload)) if verifications.auth == OnCallTokenState.INVALID: diff --git a/src/firebase_functions/tasks_fn.py b/src/firebase_functions/tasks_fn.py index 2d366b6..e0ecf3b 100644 --- a/src/firebase_functions/tasks_fn.py +++ b/src/firebase_functions/tasks_fn.py @@ -53,7 +53,10 @@ def on_task_dispatched_decorator(func: _C): @_functools.wraps(func) def on_task_dispatched_wrapped(request: Request) -> Response: - return _on_call_handler(func, request, enforce_app_check=False) + return _on_call_handler(func, + request, + enforce_app_check=False, + verify_token=False) _util.set_func_endpoint_attr( on_task_dispatched_wrapped, diff --git a/tests/test_tasks_fn.py b/tests/test_tasks_fn.py index 0e1293a..b52bd0f 100644 --- a/tests/test_tasks_fn.py +++ b/tests/test_tasks_fn.py @@ -68,3 +68,38 @@ def example(request: CallableRequest[object]) -> str: response.get_data(as_text=True), '{"result":"Hello World"}\n', ) + + def test_token_is_decoded(self): + """ + Test that the token is decoded instead of verifying auth first. + """ + app = Flask(__name__) + + @on_task_dispatched() + def example(request: CallableRequest[object]) -> str: + auth = request.auth + # Make mypy happy + if auth is None: + self.fail("Auth is None") + return "No Auth" + self.assertEqual(auth.token["sub"], "firebase") + self.assertEqual(auth.token["name"], "John Doe") + return "Hello World" + + with app.test_request_context("/"): + # pylint: disable=line-too-long + test_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJmaXJlYmFzZSIsIm5hbWUiOiJKb2huIERvZSJ9.74A24Y821E7CZx8aYCsCKo0Y-W0qXwqME-14QlEMcB0" + environ = EnvironBuilder( + method="POST", + headers={ + "Authorization": f"Bearer {test_token}" + }, + json={ + "data": { + "test": "value" + }, + }, + ).get_environ() + request = Request(environ) + response = example(request) + self.assertEqual(response.status_code, 200) diff --git a/tests/test_util.py b/tests/test_util.py index 0a8dcbb..cb13d30 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -15,7 +15,7 @@ Internal utils tests. """ from os import environ, path -from firebase_functions.private.util import firebase_config, microsecond_timestamp_conversion, nanoseconds_timestamp_conversion, get_precision_timestamp, normalize_path, deep_merge, PrecisionTimestamp, second_timestamp_conversion +from firebase_functions.private.util import firebase_config, microsecond_timestamp_conversion, nanoseconds_timestamp_conversion, get_precision_timestamp, normalize_path, deep_merge, PrecisionTimestamp, second_timestamp_conversion, _unsafe_decode_id_token import datetime as _dt test_bucket = "python-functions-testing.appspot.com" @@ -184,3 +184,11 @@ def test_does_not_modify_originals(): deep_merge(dict1, dict2) assert dict1["baz"]["answer"] == 42 assert dict2["baz"]["answer"] == 33 + + +def test_unsafe_decode_token(): + # pylint: disable=line-too-long + test_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJmaXJlYmFzZSIsIm5hbWUiOiJKb2huIERvZSJ9.74A24Y821E7CZx8aYCsCKo0Y-W0qXwqME-14QlEMcB0" + result = _unsafe_decode_id_token(test_token) + assert result["sub"] == "firebase" + assert result["name"] == "John Doe"