diff --git a/mypy.ini b/mypy.ini index a31e9c8..7712a39 100644 --- a/mypy.ini +++ b/mypy.ini @@ -3,6 +3,7 @@ exclude = build|setup.py|venv # cloudevents package has no types ignore_missing_imports = True +enable_incomplete_feature = Unpack [mypy-yaml.*] ignore_missing_imports = True diff --git a/samples/basic_tasks/.firebaserc b/samples/basic_tasks/.firebaserc new file mode 100644 index 0000000..ad27d4b --- /dev/null +++ b/samples/basic_tasks/.firebaserc @@ -0,0 +1,5 @@ +{ + "projects": { + "default": "python-functions-testing" + } +} diff --git a/samples/basic_tasks/.gitignore b/samples/basic_tasks/.gitignore new file mode 100644 index 0000000..dbb58ff --- /dev/null +++ b/samples/basic_tasks/.gitignore @@ -0,0 +1,66 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +firebase-debug.log* +firebase-debug.*.log* + +# Firebase cache +.firebase/ + +# Firebase config + +# Uncomment this if you'd like others to create their own Firebase project. +# For a team working on the same Firebase project(s), it is recommended to leave +# it commented so all members can deploy to the same project(s) in .firebaserc. +# .firebaserc + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Directory for instrumented libs generated by jscoverage/JSCover +lib-cov + +# Coverage directory used by tools like istanbul +coverage + +# nyc test coverage +.nyc_output + +# Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files) +.grunt + +# Bower dependency directory (https://bower.io/) +bower_components + +# node-waf configuration +.lock-wscript + +# Compiled binary addons (http://nodejs.org/api/addons.html) +build/Release + +# Dependency directories +node_modules/ + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variables file +.env diff --git a/samples/basic_tasks/__init__.py b/samples/basic_tasks/__init__.py new file mode 100644 index 0000000..2340b04 --- /dev/null +++ b/samples/basic_tasks/__init__.py @@ -0,0 +1,3 @@ +# Required to avoid a 'duplicate modules' mypy error +# in monorepos that have multiple main.py files. +# https://github.com/python/mypy/issues/4008 diff --git a/samples/basic_tasks/firebase.json b/samples/basic_tasks/firebase.json new file mode 100644 index 0000000..7bbd899 --- /dev/null +++ b/samples/basic_tasks/firebase.json @@ -0,0 +1,11 @@ +{ + "functions": [ + { + "source": "functions", + "codebase": "default", + "ignore": [ + "venv" + ] + } + ] +} diff --git a/samples/basic_tasks/functions/.gitignore b/samples/basic_tasks/functions/.gitignore new file mode 100644 index 0000000..34cef6b --- /dev/null +++ b/samples/basic_tasks/functions/.gitignore @@ -0,0 +1,13 @@ +# pyenv +.python-version + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Environments +.env +.venv +venv/ +venv.bak/ +__pycache__ diff --git a/samples/basic_tasks/functions/main.py b/samples/basic_tasks/functions/main.py new file mode 100644 index 0000000..da5f922 --- /dev/null +++ b/samples/basic_tasks/functions/main.py @@ -0,0 +1,69 @@ +"""Firebase Cloud Functions for Tasks.""" + +import datetime +import json + +from firebase_admin import initialize_app +from google.cloud import tasks_v2 +from firebase_functions import tasks_fn, https_fn +from firebase_functions.options import SupportedRegion, RetryConfig, RateLimits + +app = initialize_app() + + +# Once this function is deployed, a Task Queue will be created with the name +# `on_task_dispatched_example`. You can then enqueue tasks to this queue by +# calling the `enqueue_task` function. +@tasks_fn.on_task_dispatched( + retry_config=RetryConfig(max_attempts=5), + rate_limits=RateLimits(max_concurrent_dispatches=10), + region=SupportedRegion.US_CENTRAL1, +) +def ontaskdispatchedexample(req: tasks_fn.CallableRequest): + """ + The endpoint which will be executed by the enqueued task. + """ + print(req.data) + + +# To enqueue a task, you can use the following function. +# e.g. +# curl -X POST -H "Content-Type: application/json" \ +# -d '{"data": "Hello World!"}' \ +# https://enqueue-task--.a.run.app\ +@https_fn.on_request() +def enqueuetask(req: https_fn.Request) -> https_fn.Response: + """ + Enqueues a task to the queue `on_task_dispatched_function`. + """ + client = tasks_v2.CloudTasksClient() + + # The URL of the `on_task_dispatched_function` function. + # Must be set to the URL of the deployed function. + + url = req.json.get("url") if req.json else None + + body = {"data": req.json} + + task: tasks_v2.Task = tasks_v2.Task( + **{ + "http_request": { + "http_method": tasks_v2.HttpMethod.POST, + "url": url, + "headers": { + "Content-type": "application/json" + }, + "body": json.dumps(body).encode(), + }, + "schedule_time": + datetime.datetime.utcnow() + datetime.timedelta(minutes=1), + }) + + parent = client.queue_path( + app.project_id, + SupportedRegion.US_CENTRAL1, + "ontaskdispatchedexample2", + ) + + client.create_task(request={"parent": parent, "task": task}) + return https_fn.Response("Task enqueued.") diff --git a/samples/basic_tasks/functions/requirements.txt b/samples/basic_tasks/functions/requirements.txt new file mode 100644 index 0000000..8977a41 --- /dev/null +++ b/samples/basic_tasks/functions/requirements.txt @@ -0,0 +1,8 @@ +# Not published yet, +# firebase-functions-python >= 0.0.1 +# so we use a relative path during development: +./../../../ +# Or switch to git ref for deployment testing: +# git+https://github.com/firebase/firebase-functions-python.git@main#egg=firebase-functions + +firebase-admin >= 6.0.1 diff --git a/src/firebase_functions/options.py b/src/firebase_functions/options.py index 9e94763..10fe7e2 100644 --- a/src/firebase_functions/options.py +++ b/src/firebase_functions/options.py @@ -99,6 +99,62 @@ class SupportedRegion(str, _enum.Enum): US_WEST1 = "us-west1" +@_dataclasses.dataclass(frozen=True) +class RateLimits(): + """ + How congestion control should be applied to the function. + """ + max_concurrent_dispatches: int | Expression[ + int] | _util.Sentinel | None = None + """ + The maximum number of requests that can be outstanding at a time. + If left unspecified, will default to 1000. + """ + + max_dispatches_per_second: int | Expression[ + int] | _util.Sentinel | None = None + """ + The maximum number of requests that can be invoked per second. + If left unspecified, will default to 500. + """ + + +@_dataclasses.dataclass(frozen=True) +class RetryConfig(): + """ + How a task should be retried in the event of a non-2xx return. + """ + + max_attempts: int | Expression[int] | _util.Sentinel | None = None + """ + The maximum number of times a request should be attempted. + If left unspecified, will default to 3. + """ + + max_retry_seconds: int | Expression[int] | _util.Sentinel | None = None + """ + The maximum amount of time for retrying failed task. + If left unspecified will retry indefinitely. + """ + + max_backoff_seconds: int | Expression[int] | _util.Sentinel | None = None + """ + The maximum amount of time to wait between attempts. + If left unspecified will default to 1hr. + """ + + max_doublings: int | Expression[int] | _util.Sentinel | None = None + """ + The maximum number of times to double the backoff between + retries. If left unspecified will default to 16. + """ + + min_backoff_seconds: int | Expression[int] | _util.Sentinel | None = None + """ + The minimum time to wait between attempts. + """ + + @_dataclasses.dataclass(frozen=True, kw_only=True) class RuntimeOptions: """ @@ -318,6 +374,69 @@ def convert_secret( return endpoint +@_dataclasses.dataclass(frozen=True, kw_only=True) +class TaskQueueOptions(RuntimeOptions): + """ + Options specific to Tasks function types. + """ + + retry_config: RetryConfig | None = None + """ + How a task should be retried in the event of a non-2xx return. + """ + + rate_limits: RateLimits | None = None + """ + How congestion control should be applied to the function. + """ + + invoker: str | list[str] | _typing.Literal["private"] | None = None + """ + Who can enqueue tasks for this function. + + Note: + If left unspecified, only service accounts which have + `roles/cloudtasks.enqueuer` and `roles/cloudfunctions.invoker` + will have permissions. + """ + + def _endpoint( + self, + **kwargs, + ) -> _manifest.ManifestEndpoint: + rate_limits: _manifest.RateLimits | None = _manifest.RateLimits( + maxConcurrentDispatches=self.rate_limits.max_concurrent_dispatches, + maxDispatchesPerSecond=self.rate_limits.max_dispatches_per_second, + ) if self.rate_limits is not None else None + + retry_config: _manifest.RetryConfig | None = _manifest.RetryConfig( + maxAttempts=self.retry_config.max_attempts, + maxRetrySeconds=self.retry_config.max_retry_seconds, + maxBackoffSeconds=self.retry_config.max_backoff_seconds, + maxDoublings=self.retry_config.max_doublings, + minBackoffSeconds=self.retry_config.min_backoff_seconds, + ) if self.retry_config is not None else None + + kwargs_merged = { + **_dataclasses.asdict(super()._endpoint(**kwargs)), + "taskQueueTrigger": + _manifest.TaskQueueTrigger( + rateLimits=rate_limits, + retryConfig=retry_config, + ), + } + return _manifest.ManifestEndpoint( + **_typing.cast(_typing.Dict, kwargs_merged)) + + def _required_apis(self) -> list[_manifest.ManifestRequiredApi]: + return [ + _manifest.ManifestRequiredApi( + api="cloudtasks.googleapis.com", + reason="Needed for task queue functions", + ) + ] + + @_dataclasses.dataclass(frozen=True, kw_only=True) class PubSubOptions(RuntimeOptions): """ diff --git a/src/firebase_functions/private/manifest.py b/src/firebase_functions/private/manifest.py index 03e264b..c2b0140 100644 --- a/src/firebase_functions/private/manifest.py +++ b/src/firebase_functions/private/manifest.py @@ -65,14 +65,38 @@ class EventTrigger(_typing.TypedDict): class RetryConfig(_typing.TypedDict): - retryCount: _typing_extensions.NotRequired[int | _params.Expression[int]] - maxRetrySeconds: _typing_extensions.NotRequired[str | - _params.Expression[str]] - minBackoffSeconds: _typing_extensions.NotRequired[str | - _params.Expression[str]] - maxBackoffSeconds: _typing_extensions.NotRequired[str | - _params.Expression[str]] - maxDoublings: _typing_extensions.NotRequired[int | _params.Expression[int]] + """ + Retry configuration for a endpoint. + """ + maxAttempts: _typing_extensions.NotRequired[int | _params.Expression[int] | + _util.Sentinel | None] + maxRetrySeconds: _typing_extensions.NotRequired[int | + _params.Expression[int] | + _util.Sentinel | None] + maxBackoffSeconds: _typing_extensions.NotRequired[int | + _params.Expression[int] | + _util.Sentinel | None] + maxDoublings: _typing_extensions.NotRequired[int | _params.Expression[int] | + _util.Sentinel | None] + minBackoffSeconds: _typing_extensions.NotRequired[int | + _params.Expression[int] | + _util.Sentinel | None] + + +class RateLimits(_typing.TypedDict): + maxConcurrentDispatches: int | _params.Expression[ + int] | _util.Sentinel | None + + maxDispatchesPerSecond: int | _params.Expression[int] | _util.Sentinel | None + + +class TaskQueueTrigger(_typing.TypedDict): + """ + Trigger definitions for RPCs servers using the HTTP protocol defined at + https://firebase.google.com/docs/functions/callable-reference + """ + retryConfig: RetryConfig | None + rateLimits: RateLimits | None class ScheduleTrigger(_typing.TypedDict): @@ -116,6 +140,7 @@ class ManifestEndpoint: eventTrigger: EventTrigger | None = None scheduleTrigger: ScheduleTrigger | None = None blockingTrigger: BlockingTrigger | None = None + taskQueueTrigger: TaskQueueTrigger | None = None class ManifestRequiredApi(_typing.TypedDict): diff --git a/src/firebase_functions/private/serving.py b/src/firebase_functions/private/serving.py index 4bba831..edc89c8 100644 --- a/src/firebase_functions/private/serving.py +++ b/src/firebase_functions/private/serving.py @@ -65,14 +65,43 @@ def convert_value(obj): return without_nones +def merge_required_apis( + required_apis: list[_manifest.ManifestRequiredApi] +) -> list[_manifest.ManifestRequiredApi]: + api_to_reasons: dict[str, list[str]] = {} + for api_reason in required_apis: + api = api_reason["api"] + reason = api_reason["reason"] + if api not in api_to_reasons: + api_to_reasons[api] = [] + + if reason not in api_to_reasons[api]: + # Append unique reasons only + api_to_reasons[api].append(reason) + + merged: list[_manifest.ManifestRequiredApi] = [] + for api, reasons in api_to_reasons.items(): + merged.append({"api": api, "reason": " ".join(reasons)}) + + return merged + + def functions_as_yaml(functions: dict) -> str: endpoints: dict[str, _manifest.ManifestEndpoint] = {} + required_apis: list[_manifest.ManifestRequiredApi] = [] for name, function in functions.items(): endpoint = function.__firebase_endpoint__ endpoints[name] = endpoint - manifest_stack = _manifest.ManifestStack(endpoints=endpoints, - params=list( - _params._params.values())) + if hasattr(function, "__required_apis"): + for api in function.__required_apis: + required_apis.append(api) + + required_apis = merge_required_apis(required_apis) + manifest_stack = _manifest.ManifestStack( + endpoints=endpoints, + requiredAPIs=required_apis, + params=list(_params._params.values()), + ) manifest_spec = _manifest.manifest_to_spec_dict(manifest_stack) manifest_spec_with_sentinels = to_spec(manifest_spec) diff --git a/src/firebase_functions/private/util.py b/src/firebase_functions/private/util.py index 0b8e34f..2a831a1 100644 --- a/src/firebase_functions/private/util.py +++ b/src/firebase_functions/private/util.py @@ -51,12 +51,22 @@ def return_func(func: _typing.Callable[..., R]) -> _typing.Callable[P, R]: def set_func_endpoint_attr( - func: _typing.Callable[P, _typing.Any], - endpoint: _typing.Any) -> _typing.Callable[P, _typing.Any]: + func: _typing.Callable[P, _typing.Any], + endpoint: _typing.Any, +) -> _typing.Callable[P, _typing.Any]: setattr(func, "__firebase_endpoint__", endpoint) return func +def set_required_apis_attr( + func: _typing.Callable[P, _typing.Any], + required_apis: list, +): + """Set the required APIs for the current function.""" + setattr(func, "__required_apis", required_apis) + return func + + def prune_nones(obj: dict) -> dict: for key in list(obj.keys()): if obj[key] is None: diff --git a/src/firebase_functions/tasks_fn.py b/src/firebase_functions/tasks_fn.py new file mode 100644 index 0000000..63e90fe --- /dev/null +++ b/src/firebase_functions/tasks_fn.py @@ -0,0 +1,68 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Cloud functions to handle Tasks enqueued with Google Cloud Tasks.""" + +# pylint: disable=protected-access +import typing as _typing +import functools as _functools + +from flask import Request, Response + +import firebase_functions.options as _options +import firebase_functions.private.util as _util +from firebase_functions.https_fn import CallableRequest, _on_call_handler + +_C = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any] + + +@_util.copy_func_kwargs(_options.TaskQueueOptions) +def on_task_dispatched(**kwargs) -> _typing.Callable[[_C], Response]: + """ + Creates a handler for tasks sent to a Google Cloud Tasks queue. + Requires a function that takes a CallableRequest. + + Example: + + .. code-block:: python + + @tasks.on_task_dispatched() + def example(request: tasks.CallableRequest) -> Any: + return "Hello World" + + :param \\*\\*kwargs: TaskQueueOptions options. + :type \\*\\*kwargs: as :exc:`firebase_functions.options.TaskQueueOptions` + :rtype: :exc:`typing.Callable` + \\[ \\[ :exc:`firebase_functions.https.CallableRequest` \\[ + :exc:`object` \\] \\], :exc:`object` \\] + A function that takes a CallableRequest and returns an :exc:`object`. + """ + options = _options.TaskQueueOptions(**kwargs) + + def on_task_dispatched_decorator(func: _C): + + @_functools.wraps(func) + def on_task_dispatched_wrapped(request: Request) -> Response: + return _on_call_handler(func, request, enforce_app_check=False) + + _util.set_func_endpoint_attr( + on_task_dispatched_wrapped, + options._endpoint(func_name=func.__name__), + ) + _util.set_required_apis_attr( + on_task_dispatched_wrapped, + options._required_apis(), + ) + return on_task_dispatched_wrapped + + return on_task_dispatched_decorator diff --git a/tests/test_options.py b/tests/test_options.py index 22ce452..686e540 100644 --- a/tests/test_options.py +++ b/tests/test_options.py @@ -16,7 +16,7 @@ """ from firebase_functions import options, https_fn from firebase_functions import params -from firebase_functions.private.serving import functions_as_yaml +from firebase_functions.private.serving import functions_as_yaml, merge_required_apis # pylint: disable=protected-access @@ -106,3 +106,114 @@ def test_options_preserve_external_changes(): yaml = functions_as_yaml(firebase_functions2) assert " availableMemoryMb: null\n" not in yaml, "availableMemoryMb found in yaml" assert " serviceAccountEmail: null\n" not in yaml, "serviceAccountEmail found in yaml" + + +def test_merge_apis_empty_input(): + """ + This test checks the behavior of the merge_required_apis function + when the input is an empty list. The desired outcome for this test + is to receive an empty list as output. This test ensures that the + function can handle the situation where there are no input APIs to merge. + """ + required_apis = [] + expected_output = [] + merged_apis = merge_required_apis(required_apis) + + assert merged_apis == expected_output, f"Expected {expected_output}, but got {merged_apis}" + + +def test_merge_apis_no_duplicate_apis(): + """ + This test verifies that the merge_required_apis function functions + correctly when the input is a list of unique APIs with no duplicates. + The expected result is a list containing the same unique APIs in the + input list. This test confirms that the function processes and returns + APIs without modification when there is no duplication. + """ + required_apis = [ + { + "api": "API1", + "reason": "Reason 1" + }, + { + "api": "API2", + "reason": "Reason 2" + }, + { + "api": "API3", + "reason": "Reason 3" + }, + ] + + expected_output = [ + { + "api": "API1", + "reason": "Reason 1" + }, + { + "api": "API2", + "reason": "Reason 2" + }, + { + "api": "API3", + "reason": "Reason 3" + }, + ] + + merged_apis = merge_required_apis(required_apis) + + assert merged_apis == expected_output, f"Expected {expected_output}, but got {merged_apis}" + + +def test_merge_apis_duplicate_apis(): + """ + This test evaluates the merge_required_apis function when the + input list contains duplicate APIs with different reasons. + The desired outcome for this test is a list where the duplicate + APIs are merged properly and reasons are combined. + This test ensures that the function correctly merges the duplicate + APIs and combines the reasons associated with them. + """ + required_apis = [ + { + "api": "API1", + "reason": "Reason 1" + }, + { + "api": "API2", + "reason": "Reason 2" + }, + { + "api": "API1", + "reason": "Reason 3" + }, + { + "api": "API2", + "reason": "Reason 4" + }, + ] + + expected_output = [ + { + "api": "API1", + "reason": "Reason 1 Reason 3" + }, + { + "api": "API2", + "reason": "Reason 2 Reason 4" + }, + ] + + merged_apis = merge_required_apis(required_apis) + + assert len(merged_apis) == len( + expected_output + ), f"Expected a list of length {len(expected_output)}, but got {len(merged_apis)}" + + for expected_item in expected_output: + assert (expected_item in merged_apis + ), f"Expected item {expected_item} missing from the merged list" + + for actual_item in merged_apis: + assert (actual_item in expected_output + ), f"Unexpected item {actual_item} found in the merged list" diff --git a/tests/test_tasks_fn.py b/tests/test_tasks_fn.py new file mode 100644 index 0000000..0e1293a --- /dev/null +++ b/tests/test_tasks_fn.py @@ -0,0 +1,70 @@ +# Copyright 2022 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Task Queue function tests.""" +import unittest + +from unittest.mock import MagicMock +from flask import Flask, Request +from werkzeug.test import EnvironBuilder +from firebase_functions.tasks_fn import on_task_dispatched, CallableRequest + + +class TestTasks(unittest.TestCase): + """ + Task Queue function tests. + """ + + def test_on_task_dispatched_decorator(self): + """ + Tests the on_task_dispatched decorator functionality by checking + that the __firebase_endpoint__ attribute is set properly. + """ + + func = MagicMock() + func.__name__ = "testfn" + decorated_func = on_task_dispatched()(func) + endpoint = getattr(decorated_func, "__firebase_endpoint__") + self.assertIsNotNone(endpoint) + self.assertIsNotNone(endpoint.taskQueueTrigger) + + def test_task_handler(self): + """ + Test the proper execution of the task handler created by the on_task_dispatched + decorator. This test will create a Flask app, apply the on_task_dispatched + decorator to the example function, inject a request, and then ensure that a + correct response is generated. + """ + app = Flask(__name__) + + @on_task_dispatched() + def example(request: CallableRequest[object]) -> str: + self.assertEqual(request.data, {"test": "value"}) + return "Hello World" + + with app.test_request_context("/"): + environ = EnvironBuilder( + method="POST", + json={ + "data": { + "test": "value" + }, + }, + ).get_environ() + request = Request(environ) + response = example(request) + self.assertEqual(response.status_code, 200) + self.assertEqual( + response.get_data(as_text=True), + '{"result":"Hello World"}\n', + )