diff --git a/.pylintrc b/.pylintrc index ec7ab62..0f0f4e2 100644 --- a/.pylintrc +++ b/.pylintrc @@ -146,6 +146,7 @@ disable=abstract-method, wrong-import-order, xrange-builtin, zip-builtin-not-iterating, + import-outside-toplevel, [REPORTS] diff --git a/docs/generate.sh b/docs/generate.sh index 43f0c9d..28be277 100755 --- a/docs/generate.sh +++ b/docs/generate.sh @@ -86,6 +86,7 @@ PY_MODULES='firebase_functions firebase_functions.core firebase_functions.db_fn firebase_functions.https_fn + firebase_functions.identity_fn firebase_functions.options firebase_functions.params firebase_functions.pubsub_fn diff --git a/samples/identity/.firebaserc b/samples/identity/.firebaserc new file mode 100644 index 0000000..ad27d4b --- /dev/null +++ b/samples/identity/.firebaserc @@ -0,0 +1,5 @@ +{ + "projects": { + "default": "python-functions-testing" + } +} diff --git a/samples/identity/.gitignore b/samples/identity/.gitignore new file mode 100644 index 0000000..dbb58ff --- /dev/null +++ b/samples/identity/.gitignore @@ -0,0 +1,66 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +firebase-debug.log* +firebase-debug.*.log* + +# Firebase cache +.firebase/ + +# Firebase config + +# Uncomment this if you'd like others to create their own Firebase project. +# For a team working on the same Firebase project(s), it is recommended to leave +# it commented so all members can deploy to the same project(s) in .firebaserc. +# .firebaserc + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Directory for instrumented libs generated by jscoverage/JSCover +lib-cov + +# Coverage directory used by tools like istanbul +coverage + +# nyc test coverage +.nyc_output + +# Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files) +.grunt + +# Bower dependency directory (https://bower.io/) +bower_components + +# node-waf configuration +.lock-wscript + +# Compiled binary addons (http://nodejs.org/api/addons.html) +build/Release + +# Dependency directories +node_modules/ + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variables file +.env diff --git a/samples/identity/__init__.py b/samples/identity/__init__.py new file mode 100644 index 0000000..2340b04 --- /dev/null +++ b/samples/identity/__init__.py @@ -0,0 +1,3 @@ +# Required to avoid a 'duplicate modules' mypy error +# in monorepos that have multiple main.py files. +# https://github.com/python/mypy/issues/4008 diff --git a/samples/identity/client/index.html b/samples/identity/client/index.html new file mode 100644 index 0000000..a13f07f --- /dev/null +++ b/samples/identity/client/index.html @@ -0,0 +1,24 @@ + \ No newline at end of file diff --git a/samples/identity/firebase.json b/samples/identity/firebase.json new file mode 100644 index 0000000..7bbd899 --- /dev/null +++ b/samples/identity/firebase.json @@ -0,0 +1,11 @@ +{ + "functions": [ + { + "source": "functions", + "codebase": "default", + "ignore": [ + "venv" + ] + } + ] +} diff --git a/samples/identity/functions/.gitignore b/samples/identity/functions/.gitignore new file mode 100644 index 0000000..34cef6b --- /dev/null +++ b/samples/identity/functions/.gitignore @@ -0,0 +1,13 @@ +# pyenv +.python-version + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Environments +.env +.venv +venv/ +venv.bak/ +__pycache__ diff --git a/samples/identity/functions/main.py b/samples/identity/functions/main.py new file mode 100644 index 0000000..9fa44f1 --- /dev/null +++ b/samples/identity/functions/main.py @@ -0,0 +1,41 @@ +"""Firebase Cloud Functions for blocking auth functions example.""" +from firebase_functions import identity_fn + + +@identity_fn.before_user_created( + id_token=True, + access_token=True, + refresh_token=True, +) +def beforeusercreated( + event: identity_fn.AuthBlockingEvent +) -> identity_fn.BeforeCreateResponse | None: + print(event) + if not event.data.email: + return None + if "@cats.com" in event.data.email: + return identity_fn.BeforeCreateResponse(display_name="Meow!",) + if "@dogs.com" in event.data.email: + return identity_fn.BeforeCreateResponse(display_name="Woof!",) + return None + + +@identity_fn.before_user_signed_in( + id_token=True, + access_token=True, + refresh_token=True, +) +def beforeusersignedin( + event: identity_fn.AuthBlockingEvent +) -> identity_fn.BeforeSignInResponse | None: + print(event) + if not event.data.email: + return None + + if "@cats.com" in event.data.email: + return identity_fn.BeforeSignInResponse(session_claims={"emoji": "🐈"}) + + if "@dogs.com" in event.data.email: + return identity_fn.BeforeSignInResponse(session_claims={"emoji": "🐕"}) + + return None diff --git a/samples/identity/functions/requirements.txt b/samples/identity/functions/requirements.txt new file mode 100644 index 0000000..8977a41 --- /dev/null +++ b/samples/identity/functions/requirements.txt @@ -0,0 +1,8 @@ +# Not published yet, +# firebase-functions-python >= 0.0.1 +# so we use a relative path during development: +./../../../ +# Or switch to git ref for deployment testing: +# git+https://github.com/firebase/firebase-functions-python.git@main#egg=firebase-functions + +firebase-admin >= 6.0.1 diff --git a/src/firebase_functions/identity_fn.py b/src/firebase_functions/identity_fn.py new file mode 100644 index 0000000..8592f85 --- /dev/null +++ b/src/firebase_functions/identity_fn.py @@ -0,0 +1,433 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Cloud functions to handle Eventarc events.""" + +# pylint: disable=protected-access +import typing as _typing +import functools as _functools +import datetime as _dt +import dataclasses as _dataclasses + +import firebase_functions.options as _options +import firebase_functions.private.util as _util +from flask import ( + Request as _Request, + Response as _Response, +) + + +@_dataclasses.dataclass(frozen=True) +class AuthUserInfo: + """ + User info that is part of the AuthUserRecord. + """ + uid: str + """The user identifier for the linked provider.""" + + provider_id: str + """The linked provider ID (e.g., "google.com" for the Google provider).""" + + display_name: str | None = None + """The display name for the linked provider.""" + + email: str | None = None + """The email for the linked provider.""" + + photo_url: str | None = None + """The photo URL for the linked provider.""" + + phone_number: str | None = None + """The phone number for the linked provider.""" + + +@_dataclasses.dataclass(frozen=True) +class AuthUserMetadata: + """ + Additional metadata about the user. + """ + creation_time: _dt.datetime + """The date the user was created.""" + + last_sign_in_time: _dt.datetime + """The date the user last signed in.""" + + +@_dataclasses.dataclass(frozen=True) +class AuthMultiFactorInfo: + """ + Interface representing the common properties of a user-enrolled second factor. + """ + + uid: str + """ + The ID of the enrolled second factor. This ID is unique to the user. + """ + + display_name: str | None + """ + The optional display name of the enrolled second factor. + """ + + factor_id: str + """ + The type identifier of the second factor. For SMS second factors, this is `phone`. + """ + + enrollment_time: _dt.datetime | None + """ + The optional date the second factor was enrolled. + """ + + phone_number: str | None + """ + The phone number associated with a phone second factor. + """ + + +@_dataclasses.dataclass(frozen=True) +class AuthMultiFactorSettings: + """ + The multi-factor related properties for the current user, if available. + """ + + enrolled_factors: list[AuthMultiFactorInfo] + """ + List of second factors enrolled with the current user. + """ + + +@_dataclasses.dataclass(frozen=True) +class AuthUserRecord: + """ + The UserRecord passed to auth blocking Cloud Functions from the identity platform. + """ + + uid: str + """ + The user's `uid`. + """ + + email: str | None + """ + The user's primary email, if set. + """ + + email_verified: bool + """ + Whether or not the user's primary email is verified. + """ + + display_name: str | None + """ + The user's display name. + """ + + photo_url: str | None + """ + The user's photo URL. + """ + + phone_number: str | None + """ + The user's primary phone number, if set. + """ + + disabled: bool + """ + Whether or not the user is disabled: `true` for disabled; `false` for enabled. + """ + + metadata: AuthUserMetadata + """ + Additional metadata about the user. + """ + + provider_data: list[AuthUserInfo] + """ + An array of providers (e.g., Google, Facebook) linked to the user. + """ + + password_hash: str | None + """ + The user's hashed password (base64-encoded). + """ + + password_salt: str | None + """ + The user's password salt (base64-encoded). + """ + + custom_claims: dict[str, _typing.Any] | None + """ + The user's custom claims object if available. + """ + + tenant_id: str | None + """ + The ID of the tenant the user belongs to, if available. + """ + + tokens_valid_after_time: _dt.datetime | None + """The date the user's tokens are valid after.""" + + multi_factor: AuthMultiFactorSettings | None + """The multi-factor related properties for the current user, if available.""" + + +@_dataclasses.dataclass(frozen=True) +class AdditionalUserInfo: + """ + The additional user info component of the auth event context. + """ + + provider_id: str + """The provider identifier.""" + + profile: dict[str, _typing.Any] | None + """The user's profile data as a dictionary.""" + + username: str | None + """The user's username, if available.""" + + is_new_user: bool + """A boolean indicating if the user is new or not.""" + + +@_dataclasses.dataclass(frozen=True) +class Credential: + """ + The credential component of the auth event context. + """ + + claims: dict[str, _typing.Any] | None + """The user's claims data as a dictionary.""" + + id_token: str | None + """The user's ID token.""" + + access_token: str | None + """The user's access token.""" + + refresh_token: str | None + """The user's refresh token.""" + + expiration_time: _dt.datetime | None + """The expiration time of the user's access token.""" + + secret: str | None + """The user's secret.""" + + provider_id: str + """The provider identifier.""" + + sign_in_method: str + """The user's sign-in method.""" + + +@_dataclasses.dataclass(frozen=True) +class AuthBlockingEvent: + """ + Defines an auth event for identitytoolkit v2 auth blocking events. + """ + + data: AuthUserRecord + """ + The UserRecord passed to auth blocking Cloud Functions from the identity platform. + """ + + locale: str | None + """ + The application locale. You can set the locale using the client SDK, + or by passing the locale header in the REST API. + Example: 'fr' or 'sv-SE' + """ + + event_id: str + """ + The event's unique identifier. + Example: 'rWsyPtolplG2TBFoOkkgyg' + """ + + ip_address: str + """ + The IP address of the device the end user is registering or signing in from. + Example: '114.14.200.1' + """ + + user_agent: str + """ + The user agent triggering the blocking function. + Example: 'Mozilla/5.0 (X11; Linux x86_64)' + """ + + additional_user_info: AdditionalUserInfo + """An object containing information about the user.""" + + credential: Credential | None + """An object containing information about the user's credential.""" + + timestamp: _dt.datetime + """ + The time the event was triggered.""" + + +class BeforeCreateResponse(_typing.TypedDict, total=False): + """ + The handler response type for 'before_user_created' blocking events. + """ + + display_name: str | None + """The user's display name.""" + + disabled: bool | None + """Whether or not the user is disabled.""" + + email_verified: bool | None + """Whether or not the user's primary email is verified.""" + + photo_url: str | None + """The user's photo URL.""" + + custom_claims: dict[str, _typing.Any] | None + """The user's custom claims object if available.""" + + +class BeforeSignInResponse(BeforeCreateResponse, total=False): + """ + The handler response type for 'before_user_signed_in' blocking events. + """ + + session_claims: dict[str, _typing.Any] | None + """The user's session claims object if available.""" + + +BeforeUserCreatedCallable = _typing.Callable[[AuthBlockingEvent], + BeforeCreateResponse | None] +""" +The type of the callable for 'before_user_created' blocking events. +""" + +BeforeUserSignedInCallable = _typing.Callable[[AuthBlockingEvent], + BeforeSignInResponse | None] +""" +The type of the callable for 'before_user_signed_in' blocking events. +""" + + +@_util.copy_func_kwargs(_options.BlockingOptions) +def before_user_signed_in( + **kwargs, +) -> _typing.Callable[[BeforeUserSignedInCallable], BeforeUserSignedInCallable]: + """ + Handles an event that is triggered before a user is signed in. + + Example: + + .. code-block:: python + + from firebase_functions import identity_fn + + @identity_fn.before_user_signed_in() + def example(event: identity_fn.AuthBlockingEvent) -> identity_fn.BeforeSignInResponse | None: + pass + + :param \\*\\*kwargs: Options. + :type \\*\\*kwargs: as :exc:`firebase_functions.options.BlockingOptions` + :rtype: :exc:`typing.Callable` + \\[ \\[ :exc:`firebase_functions.identity_fn.AuthBlockingEvent` \\], + :exc:`firebase_functions.identity_fn.BeforeSignInResponse` \\| `None` \\] + A function that takes a AuthBlockingEvent and optionally returns BeforeSignInResponse. + """ + options = _options.BlockingOptions(**kwargs) + + def before_user_signed_in_decorator(func: BeforeUserSignedInCallable): + from firebase_functions.private._identity_fn import event_type_before_sign_in + + @_functools.wraps(func) + def before_user_signed_in_wrapped(request: _Request) -> _Response: + from firebase_functions.private._identity_fn import before_operation_handler + return before_operation_handler( + func, + event_type_before_sign_in, + request, + ) + + _util.set_func_endpoint_attr( + before_user_signed_in_wrapped, + options._endpoint( + func_name=func.__name__, + event_type=event_type_before_sign_in, + ), + ) + _util.set_required_apis_attr( + before_user_signed_in_wrapped, + options._required_apis(), + ) + return before_user_signed_in_wrapped + + return before_user_signed_in_decorator + + +@_util.copy_func_kwargs(_options.BlockingOptions) +def before_user_created( + **kwargs, +) -> _typing.Callable[[BeforeUserCreatedCallable], BeforeUserCreatedCallable]: + """ + Handles an event that is triggered before a user is created. + + Example: + + .. code-block:: python + + from firebase_functions import identity_fn + + @identity_fn.before_user_created() + def example(event: identity_fn.AuthBlockingEvent) -> identity_fn.BeforeCreateResponse | None: + pass + + :param \\*\\*kwargs: Options. + :type \\*\\*kwargs: as :exc:`firebase_functions.options.BlockingOptions` + :rtype: :exc:`typing.Callable` + \\[ \\[ :exc:`firebase_functions.identity_fn.AuthBlockingEvent` \\], + :exc:`firebase_functions.identity_fn.BeforeCreateResponse` \\| `None` \\] + A function that takes a AuthBlockingEvent and optionally returns BeforeCreateResponse. + """ + options = _options.BlockingOptions(**kwargs) + + def before_user_created_decorator(func: BeforeUserCreatedCallable): + from firebase_functions.private._identity_fn import event_type_before_create + + @_functools.wraps(func) + def before_user_created_wrapped(request: _Request) -> _Response: + from firebase_functions.private._identity_fn import before_operation_handler + return before_operation_handler( + func, + event_type_before_create, + request, + ) + + _util.set_func_endpoint_attr( + before_user_created_wrapped, + options._endpoint( + func_name=func.__name__, + event_type=event_type_before_create, + ), + ) + _util.set_required_apis_attr( + before_user_created_wrapped, + options._required_apis(), + ) + return before_user_created_wrapped + + return before_user_created_decorator diff --git a/src/firebase_functions/options.py b/src/firebase_functions/options.py index 6d05215..128b6d0 100644 --- a/src/firebase_functions/options.py +++ b/src/firebase_functions/options.py @@ -683,6 +683,62 @@ def _endpoint( **_typing.cast(_typing.Dict, kwargs_merged)) +@_dataclasses.dataclass(frozen=True, kw_only=True) +class BlockingOptions(RuntimeOptions): + """ + Options that can be set on an Auth Blocking Trigger. + Internal use only. + """ + + id_token: bool | None = None + """ + Pass the ID Token credential to the function. + """ + + access_token: bool | None = None + """ + Pass the Access Token credential to the function. + """ + + refresh_token: bool | None = None + """ + Pass the Refresh Token credential to the function. + """ + + def _endpoint( + self, + **kwargs, + ) -> _manifest.ManifestEndpoint: + assert kwargs["event_type"] is not None + + blocking_trigger = _manifest.BlockingTrigger( + eventType=kwargs["event_type"], + options=_manifest.BlockingTriggerOptions( + idToken=self.id_token if self.id_token is not None else False, + accessToken=self.access_token + if self.access_token is not None else False, + refreshToken=self.refresh_token + if self.refresh_token is not None else False, + ), + ) + + kwargs_merged = { + **_dataclasses.asdict(super()._endpoint(**kwargs)), + "blockingTrigger": + blocking_trigger, + } + return _manifest.ManifestEndpoint( + **_typing.cast(_typing.Dict, kwargs_merged)) + + def _required_apis(self) -> list[_manifest.ManifestRequiredApi]: + return [ + _manifest.ManifestRequiredApi( + api="identitytoolkit.googleapis.com", + reason="Needed for auth blocking functions", + ) + ] + + @_dataclasses.dataclass(frozen=True, kw_only=True) class FirestoreOptions(RuntimeOptions): """ diff --git a/src/firebase_functions/private/_identity_fn.py b/src/firebase_functions/private/_identity_fn.py new file mode 100644 index 0000000..e301141 --- /dev/null +++ b/src/firebase_functions/private/_identity_fn.py @@ -0,0 +1,344 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Cloud functions to handle Eventarc events.""" + +# pylint: disable=protected-access +import typing as _typing +import datetime as _dt +import time as _time +import json as _json +from firebase_functions.https_fn import HttpsError, FunctionsErrorCode + +import firebase_functions.private.util as _util +import firebase_functions.private.token_verifier as _token_verifier +from flask import ( + Request as _Request, + Response as _Response, + make_response as _make_response, + jsonify as _jsonify, +) +from functions_framework import logging as _logging + +_claims_max_payload_size = 1000 +_disallowed_custom_claims = [ + "acr", + "amr", + "at_hash", + "aud", + "auth_time", + "azp", + "cnf", + "c_hash", + "exp", + "iat", + "iss", + "jti", + "nbf", + "nonce", + "firebase", +] + + +def _auth_user_info_from_token_data(token_data: dict[str, _typing.Any]): + from firebase_functions.identity_fn import AuthUserInfo + return AuthUserInfo( + uid=token_data["uid"], + provider_id=token_data["provider_id"], + display_name=token_data.get("display_name"), + email=token_data.get("email"), + photo_url=token_data.get("photo_url"), + phone_number=token_data.get("phone_number"), + ) + + +def _auth_user_metadata_from_token_data(token_data: dict[str, _typing.Any]): + from firebase_functions.identity_fn import AuthUserMetadata + return AuthUserMetadata( + creation_time=_dt.datetime.utcfromtimestamp( + token_data["creation_time"] / 1000.0), + last_sign_in_time=_dt.datetime.utcfromtimestamp( + token_data["last_sign_in_time"] / 1000.0), + ) + + +def _auth_multi_factor_info_from_token_data(token_data: dict[str, _typing.Any]): + from firebase_functions.identity_fn import AuthMultiFactorInfo + enrollment_time = token_data.get("enrollment_time") + if enrollment_time: + enrollment_time = _dt.datetime.fromisoformat(enrollment_time) + factor_id = token_data["factor_id"] if not token_data.get( + "phone_number") else "phone" + return AuthMultiFactorInfo( + uid=token_data["uid"], + factor_id=factor_id, + display_name=token_data.get("display_name"), + enrollment_time=enrollment_time, + phone_number=token_data.get("phone_number"), + ) + + +def _auth_multi_factor_settings_from_token_data(token_data: dict[str, + _typing.Any]): + if not token_data: + return None + + from firebase_functions.identity_fn import AuthMultiFactorSettings + + enrolled_factors = [ + _auth_multi_factor_info_from_token_data(factor) + for factor in token_data.get("enrolled_factors", []) + ] + + if not enrolled_factors: + return None + + return AuthMultiFactorSettings(enrolled_factors=enrolled_factors) + + +def _auth_user_record_from_token_data(token_data: dict[str, _typing.Any]): + from firebase_functions.identity_fn import AuthUserRecord + return AuthUserRecord( + uid=token_data["uid"], + email=token_data.get("email"), + email_verified=token_data["email_verified"], + display_name=token_data.get("display_name"), + photo_url=token_data.get("photo_url"), + phone_number=token_data.get("phone_number"), + disabled=token_data.get("disabled", False), + metadata=_auth_user_metadata_from_token_data(token_data["metadata"]), + provider_data=[ + _auth_user_info_from_token_data(info) + for info in token_data["provider_data"] + ], + password_hash=token_data.get("password_hash"), + password_salt=token_data.get("password_salt"), + custom_claims=token_data.get("custom_claims"), + tenant_id=token_data.get("tenant_id"), + tokens_valid_after_time=_dt.datetime.utcfromtimestamp( + token_data["tokens_valid_after_time"]) + if token_data.get("tokens_valid_after_time") else None, + multi_factor=_auth_multi_factor_settings_from_token_data( + token_data["multi_factor"]) + if "multi_factor" in token_data else None, + ) + + +def _additional_user_info_from_token_data(token_data: dict[str, _typing.Any]): + from firebase_functions.identity_fn import AdditionalUserInfo + raw_user_info = token_data.get("raw_user_info") + profile = None + username = None + if raw_user_info: + try: + profile = _json.loads(raw_user_info) + except _json.JSONDecodeError as err: + _logging.debug(f"Parse Error: {err.msg}") + if profile: + sign_in_method = token_data.get("sign_in_method") + if sign_in_method == "github.com": + username = profile.get("login") + elif sign_in_method == "twitter.com": + username = profile.get("screen_name") + + provider_id: str = ("password" + if token_data.get("sign_in_method") == "emailLink" else + str(token_data.get("sign_in_method"))) + + is_new_user = token_data.get("event_type") == "beforeCreate" + + return AdditionalUserInfo( + provider_id=provider_id, + profile=profile, + username=username, + is_new_user=is_new_user, + ) + + +def _credential_from_token_data(token_data: dict[str, _typing.Any], + time: float): + if (not token_data.get("sign_in_attributes") and + not token_data.get("oauth_id_token") and + not token_data.get("oauth_access_token") and + not token_data.get("oauth_refresh_token")): + return None + + from firebase_functions.identity_fn import Credential + + oauth_expires_in = token_data.get("oauth_expires_in") + expiration_time = (_dt.datetime.utcfromtimestamp(time + oauth_expires_in) + if oauth_expires_in else None) + + provider_id: str = ("password" + if token_data.get("sign_in_method") == "emailLink" else + str(token_data.get("sign_in_method"))) + + return Credential( + claims=token_data.get("sign_in_attributes"), + id_token=token_data.get("oauth_id_token"), + access_token=token_data.get("oauth_access_token"), + refresh_token=token_data.get("oauth_refresh_token"), + expiration_time=expiration_time, + secret=token_data.get("oauth_token_secret"), + provider_id=provider_id, + sign_in_method=token_data["sign_in_method"], + ) + + +def _auth_blocking_event_from_token_data(token_data: dict[str, _typing.Any]): + from firebase_functions.identity_fn import AuthBlockingEvent + return AuthBlockingEvent( + data=_auth_user_record_from_token_data(token_data["user_record"]), + locale=token_data.get("locale"), + event_id=token_data["event_id"], + ip_address=token_data["ip_address"], + user_agent=token_data["user_agent"], + timestamp=_dt.datetime.fromtimestamp(token_data["iat"]), + additional_user_info=_additional_user_info_from_token_data(token_data), + credential=_credential_from_token_data(token_data, _time.time()), + ) + + +event_type_before_create = "providers/cloud.auth/eventTypes/user.beforeCreate" +event_type_before_sign_in = "providers/cloud.auth/eventTypes/user.beforeSignIn" + + +def _validate_auth_response( + event_type: str, + auth_response, +) -> dict[str, _typing.Any]: + if auth_response is None: + auth_response = {} + + custom_claims: dict[str, + _typing.Any] | None = auth_response.get("custom_claims") + session_claims: dict[str, _typing.Any] | None = auth_response.get( + "session_claims") + + if session_claims and event_type == event_type_before_create: + raise HttpsError( + FunctionsErrorCode.INVALID_ARGUMENT, + f'The session_claims claims "{",".join(session_claims)}" cannot be specified ' + f"for the before_create event.", + ) + + if custom_claims: + invalid_claims = [ + claim for claim in _disallowed_custom_claims + if claim in custom_claims + ] + + if invalid_claims: + raise HttpsError( + FunctionsErrorCode.INVALID_ARGUMENT, + f'The custom_claims claims "{",".join(invalid_claims)}" are reserved ' + f"and cannot be specified.", + ) + + if len(_json.dumps(custom_claims)) > _claims_max_payload_size: + raise HttpsError( + FunctionsErrorCode.INVALID_ARGUMENT, + f"The custom_claims payload should not exceed " + f"{_claims_max_payload_size} characters.", + ) + + if event_type == event_type_before_sign_in and session_claims: + + invalid_claims = [ + claim for claim in _disallowed_custom_claims + if claim in session_claims + ] + + if invalid_claims: + raise HttpsError( + FunctionsErrorCode.INVALID_ARGUMENT, + f'The session_claims claims "{",".join(invalid_claims)}" are reserved ' + f"and cannot be specified.", + ) + + if len(_json.dumps(session_claims)) > _claims_max_payload_size: + raise HttpsError( + FunctionsErrorCode.INVALID_ARGUMENT, + f"The session_claims payload should not exceed " + f"{_claims_max_payload_size} characters.", + ) + + combined_claims = { + **(custom_claims if custom_claims else {}), + **session_claims + } + + if len(_json.dumps(combined_claims)) > _claims_max_payload_size: + raise HttpsError( + FunctionsErrorCode.INVALID_ARGUMENT, + f"The customClaims and session_claims payloads should not exceed " + f"{_claims_max_payload_size} characters combined.", + ) + + auth_response_dict = {} + auth_response_keys = set(auth_response.keys()) + if "display_name" in auth_response_keys: + auth_response_dict["displayName"] = auth_response["display_name"] + if "disabled" in auth_response_keys: + auth_response_dict["disabled"] = auth_response["disabled"] + if "email_verified" in auth_response_keys: + auth_response_dict["emailVerified"] = auth_response["email_verified"] + if "photo_url" in auth_response_keys: + auth_response_dict["photoURL"] = auth_response["photo_url"] + if "custom_claims" in auth_response_keys: + auth_response_dict["customClaims"] = auth_response["custom_claims"] + if "session_claims" in auth_response_keys: + auth_response_dict["sessionClaims"] = auth_response["session_claims"] + return auth_response_dict + + +def before_operation_handler( + func: _typing.Callable, + event_type: str, + request: _Request, +) -> _Response: + from firebase_functions.identity_fn import BeforeCreateResponse, BeforeSignInResponse + try: + if not _util.valid_on_call_request(request): + _logging.error("Invalid request, unable to process.") + raise HttpsError(FunctionsErrorCode.INVALID_ARGUMENT, "Bad Request") + if request.json is None: + _logging.error("Request is missing body.") + raise HttpsError(FunctionsErrorCode.INVALID_ARGUMENT, "Bad Request") + if request.json is None or "data" not in request.json: + _logging.error("Request body is missing data.", request.json) + raise HttpsError(FunctionsErrorCode.INVALID_ARGUMENT, "Bad Request") + jwt_token = request.json["data"]["jwt"] + decoded_token = _token_verifier.verify_auth_blocking_token(jwt_token) + event = _auth_blocking_event_from_token_data(decoded_token) + auth_response: BeforeCreateResponse | BeforeSignInResponse | None = func( + event) + if not auth_response: + return _jsonify({}) + auth_response_dict = _validate_auth_response(event_type, auth_response) + update_mask = ",".join(auth_response_dict.keys()) + result = { + "userRecord": { + **auth_response_dict, + "updateMask": update_mask, + } + } + return _jsonify(result) + # Disable broad exceptions lint since we want to handle all exceptions. + # pylint: disable=broad-except + except Exception as exception: + if not isinstance(exception, HttpsError): + _logging.error("Unhandled error", exception) + exception = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL") + status = exception._http_error_code.status + return _make_response(_jsonify(error=exception._as_dict()), status) diff --git a/src/firebase_functions/private/manifest.py b/src/firebase_functions/private/manifest.py index b923302..864cb22 100644 --- a/src/firebase_functions/private/manifest.py +++ b/src/firebase_functions/private/manifest.py @@ -120,8 +120,15 @@ class ScheduleTrigger(_typing.TypedDict): retryConfig: RetryConfigScheduler | None +class BlockingTriggerOptions(_typing.TypedDict): + accessToken: _typing_extensions.NotRequired[bool] + idToken: _typing_extensions.NotRequired[bool] + refreshToken: _typing_extensions.NotRequired[bool] + + class BlockingTrigger(_typing.TypedDict): eventType: _typing_extensions.Required[str] + options: _typing_extensions.NotRequired[BlockingTriggerOptions] class VpcSettings(_typing.TypedDict): diff --git a/src/firebase_functions/private/token_verifier.py b/src/firebase_functions/private/token_verifier.py new file mode 100644 index 0000000..300da43 --- /dev/null +++ b/src/firebase_functions/private/token_verifier.py @@ -0,0 +1,209 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Module for internal token verification. +""" +from firebase_admin import _token_gen, exceptions, _auth_utils, initialize_app, get_app, _apps, _DEFAULT_APP_NAME +from google.auth import jwt +import google.auth.exceptions +import google.oauth2.id_token +import google.oauth2.service_account + + +# pylint: disable=consider-using-f-string +# mypy: ignore-errors +# TODO remove once firebase-admin supports this directly. +# Modified from src/firebase_admin/_token_gen.py to add +# support for app_check tokens (expected_audience kwarg and +# usage are new, plus None for audience on google.oauth2.id_token.verify_token call) +class _JWTVerifier: + """Verifies Firebase JWTs (ID tokens or session cookies).""" + + def __init__(self, **kwargs): + self.project_id = kwargs.pop('project_id') + self.short_name = kwargs.pop('short_name') + self.operation = kwargs.pop('operation') + self.url = kwargs.pop('doc_url') + self.cert_url = kwargs.pop('cert_url') + self.issuer = kwargs.pop('issuer') + self.expected_audience = kwargs.pop('expected_audience') + if self.short_name[0].lower() in 'aeiou': + self.articled_short_name = 'an {0}'.format(self.short_name) + else: + self.articled_short_name = 'a {0}'.format(self.short_name) + self._invalid_token_error = kwargs.pop('invalid_token_error') + self._expired_token_error = kwargs.pop('expired_token_error') + + def verify(self, token, request): + """Verifies the signature and data for the provided JWT.""" + token = token.encode('utf-8') if isinstance(token, str) else token + if not isinstance(token, bytes) or not token: + raise ValueError( + 'Illegal {0} provided: {1}. {0} must be a non-empty ' + 'string.'.format(self.short_name, token)) + + if not self.project_id: + raise ValueError( + 'Failed to ascertain project ID from the credential or the environment. Project ' + 'ID is required to call {0}. Initialize the app with a credentials.Certificate ' + 'or set your Firebase project ID as an app option. Alternatively set the ' + 'GOOGLE_CLOUD_PROJECT environment variable.'.format( + self.operation)) + + header, payload = self._decode_unverified(token) + issuer = payload.get('iss') + audience = payload.get('aud') + subject = payload.get('sub') + expected_issuer = self.issuer + self.project_id + + project_id_match_msg = ( + 'Make sure the {0} comes from the same Firebase project as the service account used ' + 'to authenticate this SDK.'.format(self.short_name)) + verify_id_token_msg = ( + 'See {0} for details on how to retrieve {1}.'.format( + self.url, self.short_name)) + + emulated = _auth_utils.is_emulated() + + error_message = None + if audience == _token_gen.FIREBASE_AUDIENCE: + error_message = ('{0} expects {1}, but was given a custom ' + 'token.'.format(self.operation, + self.articled_short_name)) + elif not emulated and not header.get('kid'): + if header.get('alg') == 'HS256' and payload.get( + 'v') == 0 and 'uid' in payload.get('d', {}): + error_message = ( + '{0} expects {1}, but was given a legacy custom ' + 'token.'.format(self.operation, self.articled_short_name)) + else: + error_message = 'Firebase {0} has no "kid" claim.'.format( + self.short_name) + elif not emulated and header.get('alg') != 'RS256': + error_message = ( + 'Firebase {0} has incorrect algorithm. Expected "RS256" but got ' + '"{1}". {2}'.format(self.short_name, header.get('alg'), + verify_id_token_msg)) + elif self.expected_audience and self.expected_audience not in audience: + error_message = ( + 'Firebase {0} has incorrect "aud" (audience) claim. Expected "{1}" but ' + 'got "{2}". {3} {4}'.format(self.short_name, + self.expected_audience, audience, + project_id_match_msg, + verify_id_token_msg)) + elif not self.expected_audience and audience != self.project_id: + error_message = ( + 'Firebase {0} has incorrect "aud" (audience) claim. Expected "{1}" but ' + 'got "{2}". {3} {4}'.format(self.short_name, self.project_id, + audience, project_id_match_msg, + verify_id_token_msg)) + elif issuer != expected_issuer: + error_message = ( + 'Firebase {0} has incorrect "iss" (issuer) claim. Expected "{1}" but ' + 'got "{2}". {3} {4}'.format(self.short_name, expected_issuer, + issuer, project_id_match_msg, + verify_id_token_msg)) + elif subject is None or not isinstance(subject, str): + error_message = ('Firebase {0} has no "sub" (subject) claim. ' + '{1}'.format(self.short_name, verify_id_token_msg)) + elif not subject: + error_message = ( + 'Firebase {0} has an empty string "sub" (subject) claim. ' + '{1}'.format(self.short_name, verify_id_token_msg)) + elif len(subject) > 128: + error_message = ( + 'Firebase {0} has a "sub" (subject) claim longer than 128 characters. ' + '{1}'.format(self.short_name, verify_id_token_msg)) + + if error_message: + raise self._invalid_token_error(error_message) + + try: + if emulated: + verified_claims = payload + else: + verified_claims = google.oauth2.id_token.verify_token( + token, + request=request, + # If expected_audience is set then we have already verified + # the audience above. + audience=(None + if self.expected_audience else self.project_id), + certs_url=self.cert_url) + verified_claims['uid'] = verified_claims['sub'] + return verified_claims + except google.auth.exceptions.TransportError as error: + raise _token_gen.CertificateFetchError(str(error), cause=error) + except ValueError as error: + if 'Token expired' in str(error): + raise self._expired_token_error(str(error), cause=error) + raise self._invalid_token_error(str(error), cause=error) + + def _decode_unverified(self, token): + try: + header = jwt.decode_header(token) + payload = jwt.decode(token, verify=False) + return header, payload + except ValueError as error: + raise self._invalid_token_error(str(error), cause=error) + + +class InvalidAuthBlockingTokenError(exceptions.InvalidArgumentError): + """The provided auth blocking token is not a token.""" + + default_message = 'The provided auth blocking token is invalid' + + def __init__(self, message, cause=None, http_response=None): + exceptions.InvalidArgumentError.__init__(self, message, cause, + http_response) + + +class ExpiredAuthBlockingTokenError(InvalidAuthBlockingTokenError): + """The provided auth blocking token is expired.""" + + def __init__(self, message, cause): + InvalidAuthBlockingTokenError.__init__(self, message, cause) + + +class AuthBlockingTokenVerifier(_token_gen.TokenVerifier): + """Verifies auth blocking tokens.""" + + def __init__(self, app): + super().__init__(app) + self.auth_blocking_token_verifier = _JWTVerifier( + project_id=app.project_id, + short_name='Auth Blocking token', + operation='verify_auth_blocking_token()', + doc_url= + 'https://cloud.google.com/identity-platform/docs/blocking-functions', + cert_url=_token_gen.ID_TOKEN_CERT_URI, + issuer=_token_gen.ID_TOKEN_ISSUER_PREFIX, + invalid_token_error=InvalidAuthBlockingTokenError, + expired_token_error=ExpiredAuthBlockingTokenError, + expected_audience='run.app', # v2 only + ) + + def verify_auth_blocking_token(self, auth_blocking_token): + return self.auth_blocking_token_verifier.verify( + auth_blocking_token, + self.request, + ) + + +def verify_auth_blocking_token(auth_blocking_token): + """Verifies the provided auth blocking token.""" + if _DEFAULT_APP_NAME not in _apps: + initialize_app() + return AuthBlockingTokenVerifier( + get_app()).verify_auth_blocking_token(auth_blocking_token) diff --git a/src/firebase_functions/private/util.py b/src/firebase_functions/private/util.py index 2a831a1..49bf264 100644 --- a/src/firebase_functions/private/util.py +++ b/src/firebase_functions/private/util.py @@ -85,6 +85,16 @@ def valid_on_call_request(request: _Request) -> bool: return False +def convert_keys_to_camel_case( + data: dict[str, _typing.Any]) -> dict[str, _typing.Any]: + + def snake_to_camel(word: str) -> str: + components = word.split("_") + return components[0] + "".join(x.capitalize() for x in components[1:]) + + return {snake_to_camel(key): value for key, value in data.items()} + + def _on_call_valid_body(request: _Request) -> bool: """The body must not be empty.""" if request.json is None: diff --git a/tests/test_options.py b/tests/test_options.py index 686e540..8dc0cba 100644 --- a/tests/test_options.py +++ b/tests/test_options.py @@ -80,15 +80,15 @@ def test_options_preserve_external_changes(): """ Testing if setting a global option internally change the values. """ - assert (options._GLOBAL_OPTIONS.preserve_external_changes is - None), "option should not already be set" + assert (options._GLOBAL_OPTIONS.preserve_external_changes + is None), "option should not already be set" options.set_global_options( preserve_external_changes=False, min_instances=5, ) options_asdict = options._GLOBAL_OPTIONS._asdict_with_global_options() - assert (options_asdict["max_instances"] is - options.RESET_VALUE), "option should be RESET_VALUE" + assert (options_asdict["max_instances"] + is options.RESET_VALUE), "option should be RESET_VALUE" assert options_asdict["min_instances"] == 5, "option should be set" firebase_functions = { diff --git a/tests/test_params.py b/tests/test_params.py index ed62b2d..39284e8 100644 --- a/tests/test_params.py +++ b/tests/test_params.py @@ -27,11 +27,11 @@ def test_bool_param_value_true_or_false(self): for value_true, value_false in zip(["true", "t", "1", "y", "yes"], ["false", "f", "0", "n", "no"]): environ["BOOL_VALUE_TEST1"] = value_true - assert (bool_param.value() is - True), "Failure, params returned False" + assert (bool_param.value() + is True), "Failure, params returned False" environ["BOOL_VALUE_TEST1"] = value_false - assert (bool_param.value() is - False), "Failure, params returned True" + assert (bool_param.value() + is False), "Failure, params returned True" def test_bool_param_value_error(self): """Testing if bool params throws a value error if invalid value.""" @@ -41,26 +41,26 @@ def test_bool_param_value_error(self): def test_bool_param_empty_default(self): """Testing if bool params defaults to False if no value and no default.""" - assert (params.BoolParam("BOOL_DEFAULT_TEST").value() is - False), "Failure, params returned True" + assert (params.BoolParam("BOOL_DEFAULT_TEST").value() + is False), "Failure, params returned True" def test_bool_param_default(self): """Testing if bool params defaults to provided default value.""" assert (params.BoolParam("BOOL_DEFAULT_TEST_FALSE", - default=False).value() is - False), "Failure, params returned True" + default=False).value() + is False), "Failure, params returned True" assert (params.BoolParam("BOOL_DEFAULT_TEST_TRUE", - default=True).value() is - True), "Failure, params returned False" + default=True).value() + is True), "Failure, params returned False" def test_bool_param_equality(self): """Test bool equality.""" assert (params.BoolParam("BOOL_TEST1", - default=False).equals(False).value() is - True), "Failure, equality check returned False" + default=False).equals(False).value() + is True), "Failure, equality check returned False" assert (params.BoolParam("BOOL_TEST2", - default=True).equals(False).value() is - False), "Failure, equality check returned False" + default=True).equals(False).value() + is False), "Failure, equality check returned False" class TestFloatParams: @@ -86,11 +86,11 @@ def test_float_param_default(self): def test_float_param_equality(self): """Test float equality.""" assert (params.FloatParam("FLOAT_TEST1", - default=123.456).equals(123.456).value() is - True), "Failure, equality check returned False" + default=123.456).equals(123.456).value() + is True), "Failure, equality check returned False" assert (params.FloatParam("FLOAT_TEST2", - default=456.789).equals(123.456).value() is - False), "Failure, equality check returned False" + default=456.789).equals(123.456).value() + is False), "Failure, equality check returned False" class TestIntParams: @@ -114,10 +114,10 @@ def test_int_param_default(self): def test_int_param_equality(self): """Test int equality.""" - assert (params.IntParam("INT_TEST1", default=123).equals(123).value() is - True), "Failure, equality check returned False" - assert (params.IntParam("INT_TEST2", default=456).equals(123).value() is - False), "Failure, equality check returned False" + assert (params.IntParam("INT_TEST1", default=123).equals(123).value() + is True), "Failure, equality check returned False" + assert (params.IntParam("INT_TEST2", default=456).equals(123).value() + is False), "Failure, equality check returned False" class TestStringParams: @@ -150,11 +150,11 @@ def test_string_param_default(self): def test_string_param_equality(self): """Test string equality.""" assert (params.StringParam("STRING_TEST1", - default="123").equals("123").value() is - True), "Failure, equality check returned False" + default="123").equals("123").value() + is True), "Failure, equality check returned False" assert (params.StringParam("STRING_TEST2", - default="456").equals("123").value() is - False), "Failure, equality check returned False" + default="456").equals("123").value() + is False), "Failure, equality check returned False" class TestListParams: @@ -185,10 +185,12 @@ def test_list_param_default(self): def test_list_param_equality(self): """Test list equality.""" - assert (params.ListParam("LIST_TEST1", default=["123"]).equals( - ["123"]).value() is True), "Failure, equality check returned False" - assert (params.ListParam("LIST_TEST2", default=["456"]).equals( - ["123"]).value() is False), "Failure, equality check returned False" + assert (params.ListParam("LIST_TEST1", + default=["123"]).equals(["123"]).value() + is True), "Failure, equality check returned False" + assert (params.ListParam("LIST_TEST2", + default=["456"]).equals(["123"]).value() + is False), "Failure, equality check returned False" class TestParamsManifest: