From cae8f3370b6652fc8652a42a54f36c8b080657f5 Mon Sep 17 00:00:00 2001 From: Jonathan Edey Date: Mon, 7 Apr 2025 10:55:08 -0400 Subject: [PATCH 1/5] httpx async_send_each prototype --- firebase_admin/_utils.py | 87 ++++++++++++++ firebase_admin/messaging.py | 179 +++++++++++++++++++++++++++-- integration/conftest.py | 26 +++-- integration/test_messaging.py | 84 ++++++++++++++ requirements.txt | 4 +- tests/test_messaging.py | 206 ++++++++++++++++++++++++++++++++++ 6 files changed, 566 insertions(+), 20 deletions(-) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index b6e292546..f35e7cf65 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -16,9 +16,11 @@ import json from platform import python_version +from typing import Callable, Optional, Union import google.auth import requests +import httpx import firebase_admin from firebase_admin import exceptions @@ -128,6 +130,37 @@ def handle_platform_error_from_requests(error, handle_func=None): return exc if exc else _handle_func_requests(error, message, error_dict) +def handle_platform_error_from_httpx( + error: httpx.HTTPError, + handle_func: Optional[Callable[...,Optional[exceptions.FirebaseError]]] = None +) -> exceptions.FirebaseError: + """Constructs a ``FirebaseError`` from the given httpx error. + + This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. + + Args: + error: An error raised by the httpx module while making an HTTP call to a GCP API. + handle_func: A function that can be used to handle platform errors in a custom way. When + specified, this function will be called with three arguments. It has the same + signature as ```_handle_func_httpx``, but may return ``None``. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + + if isinstance(error, httpx.HTTPStatusError): + response = error.response + content = response.content.decode() + status_code = response.status_code + error_dict, message = _parse_platform_error(content, status_code) + exc = None + if handle_func: + exc = handle_func(error, message, error_dict) + + return exc if exc else _handle_func_httpx(error, message, error_dict) + else: + return handle_httpx_error(error) + def handle_operation_error(error): """Constructs a ``FirebaseError`` from the given operation error. @@ -204,6 +237,60 @@ def handle_requests_error(error, message=None, code=None): err_type = _error_code_to_exception_type(code) return err_type(message=message, cause=error, http_response=error.response) +def _handle_func_httpx(error: httpx.HTTPError, message, error_dict) -> exceptions.FirebaseError: + """Constructs a ``FirebaseError`` from the given GCP error. + + Args: + error: An error raised by the httpx module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError``. + error_dict: Parsed GCP error response. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code or None. + """ + code = error_dict.get('status') + return handle_httpx_error(error, message, code) + + +def handle_httpx_error(error: httpx.HTTPError, message=None, code=None) -> exceptions.FirebaseError: + """Constructs a ``FirebaseError`` from the given httpx error. + + This method is agnostic of the remote service that produced the error, whether it is a GCP + service or otherwise. Therefore, this method does not attempt to parse the error response in + any way. + + Args: + error: An error raised by the httpx module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError`` (optional). If not + specified the string representation of the ``error`` argument is used as the message. + code: A GCP error code that will be used to determine the resulting error type (optional). + If not specified the HTTP status code on the error response is used to determine a + suitable error code. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if isinstance(error, httpx.TimeoutException): + return exceptions.DeadlineExceededError( + message='Timed out while making an API call: {0}'.format(error), + cause=error) + if isinstance(error, httpx.ConnectError): + return exceptions.UnavailableError( + message='Failed to establish a connection: {0}'.format(error), + cause=error) + if isinstance(error, httpx.HTTPStatusError): + print("printing status error", error) + if not code: + code = _http_status_to_error_code(error.response.status_code) + if not message: + message = str(error) + + err_type = _error_code_to_exception_type(code) + return err_type(message=message, cause=error, http_response=error.response) + + return exceptions.UnknownError( + message='Unknown error while making a remote service call: {0}'.format(error), + cause=error) def _http_status_to_error_code(status): """Maps an HTTP status to a platform error code.""" diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index d2ad04a04..6ce839c89 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -14,11 +14,16 @@ """Firebase Cloud Messaging module.""" +from __future__ import annotations +from typing import Callable, List, Optional, TypeVar import concurrent.futures import json import warnings import requests +import httpx +import asyncio +from google.auth import credentials, transport from googleapiclient import http from googleapiclient import _auth @@ -62,6 +67,7 @@ 'WebpushNotification', 'WebpushNotificationAction', + 'async_send_each' 'send', 'send_all', 'send_multicast', @@ -69,8 +75,9 @@ 'send_each_for_multicast', 'subscribe_to_topic', 'unsubscribe_from_topic', -] +] # type: ignore +TFirebaseError = TypeVar('TFirebaseError', bound=exceptions.FirebaseError) AndroidConfig = _messaging_utils.AndroidConfig AndroidFCMOptions = _messaging_utils.AndroidFCMOptions @@ -97,7 +104,7 @@ UnregisteredError = _messaging_utils.UnregisteredError -def _get_messaging_service(app): +def _get_messaging_service(app) -> _MessagingService: return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService) def send(message, dry_run=False, app=None): @@ -140,6 +147,9 @@ def send_each(messages, dry_run=False, app=None): """ return _get_messaging_service(app).send_each(messages, dry_run) +async def async_send_each(messages, dry_run=True, app: firebase_admin.App | None = None) -> BatchResponse: + return await _get_messaging_service(app).async_send_each(messages, dry_run) + def send_each_for_multicast(multicast_message, dry_run=False, app=None): """Sends the given mutlicast message to each token via Firebase Cloud Messaging (FCM). @@ -321,21 +331,21 @@ def errors(self): class BatchResponse: """The response received from a batch request to the FCM API.""" - def __init__(self, responses): + def __init__(self, responses: List[SendResponse]) -> None: self._responses = responses self._success_count = len([resp for resp in responses if resp.success]) @property - def responses(self): + def responses(self) -> List[SendResponse]: """A list of ``messaging.SendResponse`` objects (possibly empty).""" return self._responses @property - def success_count(self): + def success_count(self) -> int: return self._success_count @property - def failure_count(self): + def failure_count(self) -> int: return len(self.responses) - self.success_count @@ -363,6 +373,51 @@ def exception(self): """A ``FirebaseError`` if an error occurs while sending the message to the FCM service.""" return self._exception +# Auth Flow +# The aim here is to be able to get auth credentials right before the request is sent. +# This is similar to what is done in transport.requests.AuthorizedSession(). +# We can then pass this in at the client level. +class CustomGoogleAuth(httpx.Auth): + def __init__(self, credentials: credentials.Credentials): + self._credential = credentials + self._max_refresh_attempts = 2 + self._refresh_status_codes = (401,) + + def apply_auth_headers(self, request: httpx.Request): + # Build request used to refresh credentials if needed + auth_request = transport.requests.Request() # type: ignore + # This refreshes the credentials if needed and mutates the request headers to contain access token + # and any other google auth headers + self._credential.before_request(auth_request, request.method, request.url, request.headers) + + + def auth_flow(self, request: httpx.Request): + # Keep original headers since `credentials.before_request` mutates the passed headers and we + # want to keep the original in cause we need an auth retry. + _original_headers = request.headers.copy() + + _credential_refresh_attempt = 0 + while ( + _credential_refresh_attempt < self._max_refresh_attempts + ): + # copy original headers + request.headers = _original_headers.copy() + # mutates request headers + self.apply_auth_headers(request) + + # Continue to perform the request + # yield here dispatches the request and returns with the response + response: httpx.Response = yield request + + # We can check the result of the response and determine in we need to retry on refreshable status codes. + # Current transport.requests.AuthorizedSession() only does this on 401 errors. We should do the same. + if response.status_code in self._refresh_status_codes: + _credential_refresh_attempt += 1 + print(response.status_code, response.reason_phrase, _credential_refresh_attempt) + else: + break; + + class _MessagingService: """Service class that implements Firebase Cloud Messaging (FCM) functionality.""" @@ -381,7 +436,7 @@ class _MessagingService: 'UNREGISTERED': UnregisteredError, } - def __init__(self, app): + def __init__(self, app) -> None: project_id = app.project_id if not project_id: raise ValueError( @@ -396,6 +451,12 @@ def __init__(self, app): timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self._credential = app.credential.get_credential() self._client = _http_client.JsonHttpClient(credential=self._credential, timeout=timeout) + self._async_client = httpx.AsyncClient( + http2=True, + auth=CustomGoogleAuth(self._credential), + timeout=timeout, + transport=HttpxRetryTransport() + ) self._build_transport = _auth.authorized_http @classmethod @@ -423,7 +484,7 @@ def send_each(self, messages, dry_run=False): """Sends the given messages to FCM via the FCM v1 API.""" if not isinstance(messages, list): raise ValueError('messages must be a list of messaging.Message instances.') - if len(messages) > 500: + if len(messages) > 1000: raise ValueError('messages must not contain more than 500 elements.') def send_data(data): @@ -448,6 +509,40 @@ def send_data(data): message='Unknown error while making remote service calls: {0}'.format(error), cause=error) + async def async_send_each(self, messages: List[Message], dry_run: bool = True) -> BatchResponse: + """Sends the given messages to FCM via the FCM v1 API.""" + if not isinstance(messages, list): + raise ValueError('messages must be a list of messaging.Message instances.') + if len(messages) > 1000: + raise ValueError('messages must not contain more than 500 elements.') + + async def send_data(data): + try: + resp = await self._async_client.request( + 'post', + url=self._fcm_url, + headers=self._fcm_headers, + json=data) + # HTTP/2 check + if resp.http_version != 'HTTP/2': + raise Exception('This messages was not sent with HTTP/2') + resp.raise_for_status() + # except httpx.HTTPStatusError as exception: + except httpx.HTTPError as exception: + return SendResponse(resp=None, exception=self._handle_fcm_httpx_error(exception)) + else: + return SendResponse(resp.json(), exception=None) + + message_data = [self._message_data(message, dry_run) for message in messages] + try: + responses = await asyncio.gather(*[send_data(message) for message in message_data]) + return BatchResponse(responses) + except Exception as error: + raise exceptions.UnknownError( + message='Unknown error while making remote service calls: {0}'.format(error), + cause=error) + + def send_all(self, messages, dry_run=False): """Sends the given messages to FCM via the batch API.""" if not isinstance(messages, list): @@ -533,6 +628,11 @@ def _handle_fcm_error(self, error): return _utils.handle_platform_error_from_requests( error, _MessagingService._build_fcm_error_requests) + def _handle_fcm_httpx_error(self, error: httpx.HTTPError) -> exceptions.FirebaseError: + """Handles errors received from the FCM API.""" + return _utils.handle_platform_error_from_httpx( + error, _MessagingService._build_fcm_error_httpx) + def _handle_iid_error(self, error): """Handles errors received from the Instance ID API.""" if error.response is None: @@ -561,6 +661,13 @@ def _handle_batch_error(self, error): """Handles errors received from the googleapiclient while making batch requests.""" return _gapic_utils.handle_platform_error_from_googleapiclient( error, _MessagingService._build_fcm_error_googleapiclient) + + # We should be careful to clean up the httpx clients. + # Since we are using an async client we must also close in async. However we can sync wrap this. + # The close method is called by the app on shutdown/clean-up of each service. We don't seem to + # make use of this much elsewhere. + def close(self) -> None: + asyncio.run(self._async_client.aclose()) @classmethod def _build_fcm_error_requests(cls, error, message, error_dict): @@ -569,6 +676,17 @@ def _build_fcm_error_requests(cls, error, message, error_dict): exc_type = cls._build_fcm_error(error_dict) return exc_type(message, cause=error, http_response=error.response) if exc_type else None + @classmethod + def _build_fcm_error_httpx(cls, error: httpx.HTTPError, message, error_dict) -> Optional[exceptions.FirebaseError]: + """Parses a httpx error response from the FCM API and creates a FCM-specific exception if + appropriate.""" + exc_type = cls._build_fcm_error(error_dict) + if isinstance(error, httpx.HTTPStatusError): + return exc_type(message, cause=error, http_response=error.response) if exc_type else None + else: + return exc_type(message, cause=error) if exc_type else None + + @classmethod def _build_fcm_error_googleapiclient(cls, error, message, error_dict, http_response): """Parses an error response from the FCM API and creates a FCM-specific exception if @@ -577,7 +695,7 @@ def _build_fcm_error_googleapiclient(cls, error, message, error_dict, http_respo return exc_type(message, cause=error, http_response=http_response) if exc_type else None @classmethod - def _build_fcm_error(cls, error_dict): + def _build_fcm_error(cls, error_dict) -> Optional[Callable[..., exceptions.FirebaseError]]: if not error_dict: return None fcm_code = None @@ -585,4 +703,45 @@ def _build_fcm_error(cls, error_dict): if detail.get('@type') == 'type.googleapis.com/google.firebase.fcm.v1.FcmError': fcm_code = detail.get('errorCode') break - return _MessagingService.FCM_ERROR_TYPES.get(fcm_code) + return _MessagingService.FCM_ERROR_TYPES.get(fcm_code) if fcm_code else None + + +class HttpxRetryTransport(httpx.AsyncBaseTransport): + # We could also support passing kwargs here + def __init__(self) -> None: + self._retryable_status_codes = (500, 503,) + self._max_retry_count = 4 + + # We should use a full AsyncHTTPTransport under the hood since that is + # fully implemented. We could consider making this class extend a + # AsyncHTTPTransport instead and use the parent class's methods to handle + # requests. We sould also ensure that that transport's internal retry is + # not enabled. + self._wrapped_transport = httpx.AsyncHTTPTransport(retries=0, http2=True) + + # Checklist: + # - Do we want to disable built in retries + # - Can we dispatch the same request multiple times? Is there any side effects? + + # Two types of retries + # - Status code (500s, redirect) + # - Error code (read, connect, other) + # - more ??? + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + _retry_count = 0 + + while True: + # Dispatch request + response = await self._wrapped_transport.handle_async_request(request) + + # Check if request is retryable + if response.status_code in self._retryable_status_codes: + _retry_count += 1 + + # Figure out how we want to handle 0 here + if _retry_count > self._max_retry_count: + return response + else: + return response + # break; \ No newline at end of file diff --git a/integration/conftest.py b/integration/conftest.py index 71f53f612..ee9509017 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -15,8 +15,9 @@ """pytest configuration and global fixtures for integration tests.""" import json -import asyncio +# import asyncio import pytest +from pytest_asyncio import is_async_test import firebase_admin from firebase_admin import credentials @@ -72,11 +73,18 @@ def api_key(request): with open(path) as keyfile: return keyfile.read().strip() -@pytest.fixture(scope="session") -def event_loop(): - """Create an instance of the default event loop for test session. - This avoids early eventloop closure. - """ - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() +# @pytest.fixture(scope="session") +# def event_loop(): +# """Create an instance of the default event loop for test session. +# This avoids early eventloop closure. +# """ +# loop = asyncio.get_event_loop_policy().new_event_loop() +# yield loop +# loop.close() + +# +def pytest_collection_modifyitems(items): + pytest_asyncio_tests = (item for item in items if is_async_test(item)) + session_scope_marker = pytest.mark.asyncio(loop_scope="session") + for async_test in pytest_asyncio_tests: + async_test.add_marker(session_scope_marker, append=False) \ No newline at end of file diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 4c1d7d0dc..dfd2a8684 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -14,6 +14,7 @@ """Integration tests for firebase_admin.messaging module.""" +import asyncio import re from datetime import datetime @@ -221,3 +222,86 @@ def test_subscribe(): def test_unsubscribe(): resp = messaging.unsubscribe_from_topic(_REGISTRATION_TOKEN, 'mock-topic') assert resp.success_count + resp.failure_count == 1 + +@pytest.mark.asyncio +async def test_async_send_each(): + messages = [ + messaging.Message( + topic='foo-bar', notification=messaging.Notification('Title', 'Body')), + messaging.Message( + topic='foo-bar', notification=messaging.Notification('Title', 'Body')), + messaging.Message( + token='not-a-token', notification=messaging.Notification('Title', 'Body')), + ] + + batch_response = await messaging.async_send_each(messages, dry_run=True) + + assert batch_response.success_count == 2 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 3 + + response = batch_response.responses[0] + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) + + response = batch_response.responses[1] + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) + + response = batch_response.responses[2] + assert response.success is False + assert isinstance(response.exception, exceptions.InvalidArgumentError) + assert response.message_id is None + + +# @pytest.mark.asyncio +# async def test_async_send_each_error(): +# messages = [ +# messaging.Message( +# topic='foo-bar', notification=messaging.Notification('Title', 'Body')), +# messaging.Message( +# topic='foo-bar', notification=messaging.Notification('Title', 'Body')), +# messaging.Message( +# token='not-a-token', notification=messaging.Notification('Title', 'Body')), +# ] + +# batch_response = await messaging.async_send_each(messages, dry_run=True) + +# assert batch_response.success_count == 2 +# assert batch_response.failure_count == 1 +# assert len(batch_response.responses) == 3 + +# response = batch_response.responses[0] +# assert response.success is True +# assert response.exception is None +# assert re.match('^projects/.*/messages/.*$', response.message_id) + +# response = batch_response.responses[1] +# assert response.success is True +# assert response.exception is None +# assert re.match('^projects/.*/messages/.*$', response.message_id) + +# response = batch_response.responses[2] +# assert response.success is False +# assert isinstance(response.exception, exceptions.InvalidArgumentError) +# assert response.message_id is None + +@pytest.mark.asyncio +async def test_async_send_each_500(): + messages = [] + for msg_number in range(500): + topic = 'foo-bar-{0}'.format(msg_number % 10) + messages.append(messaging.Message(topic=topic)) + + batch_response = await messaging.async_send_each(messages, dry_run=True) + + assert batch_response.success_count == 500 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 500 + for response in batch_response.responses: + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) + diff --git a/requirements.txt b/requirements.txt index fd5b0b39c..c662dd53a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,10 +5,12 @@ pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 pytest-asyncio >= 0.16.0 pytest-mock >= 3.6.1 +respx == 0.22.0 cachecontrol >= 0.12.14 google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 google-cloud-firestore >= 2.19.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.37.1 -pyjwt[crypto] >= 2.5.0 \ No newline at end of file +pyjwt[crypto] >= 2.5.0 +httpx == 0.28.1 \ No newline at end of file diff --git a/tests/test_messaging.py b/tests/test_messaging.py index b7b5c69ba..f4ce943e4 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -14,12 +14,15 @@ """Test cases for the firebase_admin.messaging module.""" import datetime +from itertools import chain, repeat import json import numbers +import respx from googleapiclient import http from googleapiclient import _helpers import pytest +import httpx import firebase_admin from firebase_admin import exceptions @@ -1924,6 +1927,209 @@ def test_send_each(self): assert all([r.success for r in batch_response.responses]) assert not any([r.exception for r in batch_response.responses]) + @respx.mock + @pytest.mark.asyncio + async def test_async_send_each(self): + responses = [ + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'}), + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id2'}), + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id3'}), + ] + msg1 = messaging.Message(topic='foo1') + msg2 = messaging.Message(topic='foo2') + msg3 = messaging.Message(topic='foo3') + route = respx.request('POST', 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send').mock(side_effect=responses) + + batch_response = await messaging.async_send_each([msg1, msg2, msg3], dry_run=True) + + # try: + # batch_response = await messaging.async_send_each([msg1, msg2], dry_run=True) + # except Exception as error: + # if isinstance(error.cause.__cause__, StopIteration): + # raise Exception('Received more requests than mocks') + + assert batch_response.success_count == 3 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 3 + assert [r.message_id for r in batch_response.responses] == ['message-id1', 'message-id2', 'message-id3'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + assert route.call_count == 3 + + @respx.mock + @pytest.mark.asyncio + async def test_async_send_each_error_401_fail_auth_retry(self): + payload = json.dumps({ + 'error': { + 'status': 'UNAUTHENTICATED', + 'message': 'test unauthenticated error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'SOME_UNKNOWN_CODE', + }, + ], + } + }) + + responses = repeat(respx.MockResponse(401, http_version='HTTP/2', content=payload)) + + msg1 = messaging.Message(topic='foo1') + route = respx.request('POST', 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send').mock(side_effect=responses) + batch_response = await messaging.async_send_each([msg1], dry_run=True) + + assert route.call_count == 2 + assert batch_response.success_count == 0 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 1 + exception = batch_response.responses[0].exception + assert isinstance(exception, exceptions.UnauthenticatedError) + + @respx.mock + @pytest.mark.asyncio + async def test_async_send_each_error_401_pass_on_auth_retry(self): + payload = json.dumps({ + 'error': { + 'status': 'UNAUTHENTICATED', + 'message': 'test unauthenticated error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'SOME_UNKNOWN_CODE', + }, + ], + } + }) + responses = [ + respx.MockResponse(401, http_version='HTTP/2', content=payload), + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'}), + ] + + msg1 = messaging.Message(topic='foo1') + route = respx.request('POST', 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send').mock(side_effect=responses) + batch_response = await messaging.async_send_each([msg1], dry_run=True) + + assert route.call_count == 2 + assert batch_response.success_count == 1 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 1 + assert [r.message_id for r in batch_response.responses] == ['message-id1'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + @respx.mock + @pytest.mark.asyncio + async def test_async_send_each_error_500_fail_retry_config(self): + payload = json.dumps({ + 'error': { + 'status': 'INTERNAL', + 'message': 'test INTERNAL error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'SOME_UNKNOWN_CODE', + }, + ], + } + }) + + responses = repeat(respx.MockResponse(500, http_version='HTTP/2', content=payload)) + + msg1 = messaging.Message(topic='foo1') + route = respx.request('POST', 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send').mock(side_effect=responses) + batch_response = await messaging.async_send_each([msg1], dry_run=True) + + assert route.call_count == 5 + assert batch_response.success_count == 0 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 1 + exception = batch_response.responses[0].exception + assert isinstance(exception, exceptions.InternalError) + + + @respx.mock + @pytest.mark.asyncio + async def test_async_send_each_error_500_pass_on_retry_config(self): + payload = json.dumps({ + 'error': { + 'status': 'INTERNAL', + 'message': 'test INTERNAL error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'SOME_UNKNOWN_CODE', + }, + ], + } + }) + responses = chain( + [ + respx.MockResponse(500, http_version='HTTP/2', content=payload), + respx.MockResponse(500, http_version='HTTP/2', content=payload), + respx.MockResponse(500, http_version='HTTP/2', content=payload), + respx.MockResponse(500, http_version='HTTP/2', content=payload), + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'}), + ], + ) + + msg1 = messaging.Message(topic='foo1') + route = respx.request('POST', 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send').mock(side_effect=responses) + batch_response = await messaging.async_send_each([msg1], dry_run=True) + + assert route.call_count == 5 + assert batch_response.success_count == 1 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 1 + assert [r.message_id for r in batch_response.responses] == ['message-id1'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + # @respx.mock + # @pytest.mark.asyncio + # async def test_async_send_each_error_request(self): + # payload = json.dumps({ + # 'error': { + # 'status': 'INTERNAL', + # 'message': 'test INTERNAL error', + # 'details': [ + # { + # '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + # 'errorCode': 'SOME_UNKNOWN_CODE', + # }, + # ], + # } + # }) + # responses = chain( + # [ + # httpx.ConnectError("Test request error", request=httpx.Request('POST', 'URL')) + # # respx.MockResponse(500, http_version='HTTP/2', content=payload), + # ], + # # repeat(respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'})), + # # respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id3'}), + # ) + + # # responses = repeat(respx.MockResponse(500, http_version='HTTP/2', content=payload)) + + # msg1 = messaging.Message(topic='foo1') + # route = respx.request('POST', 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send').mock(side_effect=responses) + # batch_response = await messaging.async_send_each([msg1], dry_run=True) + + # assert route.call_count == 1 + # assert batch_response.success_count == 0 + # assert batch_response.failure_count == 1 + # assert len(batch_response.responses) == 1 + # exception = batch_response.responses[0].exception + # assert isinstance(exception, exceptions.UnavailableError) + + # # assert route.call_count == 4 + # # assert batch_response.success_count == 1 + # # assert batch_response.failure_count == 0 + # # assert len(batch_response.responses) == 1 + # # assert [r.message_id for r in batch_response.responses] == ['message-id1'] + # # assert all([r.success for r in batch_response.responses]) + # # assert not any([r.exception for r in batch_response.responses]) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_each_detailed_error(self, status): success_payload = json.dumps({'name': 'message-id'}) From e3aa0db9a1fd35358c0a1dcb89435649441f9520 Mon Sep 17 00:00:00 2001 From: Jonathan Edey Date: Tue, 8 Apr 2025 10:24:06 -0400 Subject: [PATCH 2/5] Clean up code and lint --- firebase_admin/_utils.py | 9 +-- firebase_admin/messaging.py | 138 ++++++++++++++++++---------------- integration/conftest.py | 3 +- integration/test_messaging.py | 14 ++-- tests/test_messaging.py | 101 +++++++++++++++---------- 5 files changed, 146 insertions(+), 119 deletions(-) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index f35e7cf65..765d11587 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -16,7 +16,7 @@ import json from platform import python_version -from typing import Callable, Optional, Union +from typing import Callable, Optional import google.auth import requests @@ -131,8 +131,8 @@ def handle_platform_error_from_requests(error, handle_func=None): return exc if exc else _handle_func_requests(error, message, error_dict) def handle_platform_error_from_httpx( - error: httpx.HTTPError, - handle_func: Optional[Callable[...,Optional[exceptions.FirebaseError]]] = None + error: httpx.HTTPError, + handle_func: Optional[Callable[..., Optional[exceptions.FirebaseError]]] = None ) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given httpx error. @@ -158,8 +158,7 @@ def handle_platform_error_from_httpx( exc = handle_func(error, message, error_dict) return exc if exc else _handle_func_httpx(error, message, error_dict) - else: - return handle_httpx_error(error) + return handle_httpx_error(error) def handle_operation_error(error): diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 6ce839c89..4a2b16642 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -15,25 +15,29 @@ """Firebase Cloud Messaging module.""" from __future__ import annotations -from typing import Callable, List, Optional, TypeVar +from typing import Callable, List, Optional import concurrent.futures import json import warnings +import asyncio import requests import httpx -import asyncio -from google.auth import credentials, transport +from google.auth import credentials +from google.auth.transport import requests as auth_requests from googleapiclient import http from googleapiclient import _auth import firebase_admin -from firebase_admin import _http_client -from firebase_admin import _messaging_encoder -from firebase_admin import _messaging_utils -from firebase_admin import _gapic_utils -from firebase_admin import _utils -from firebase_admin import exceptions +from firebase_admin import ( + _http_client, + _messaging_encoder, + _messaging_utils, + _gapic_utils, + _utils, + exceptions, + App +) _MESSAGING_ATTRIBUTE = '_messaging' @@ -67,17 +71,16 @@ 'WebpushNotification', 'WebpushNotificationAction', - 'async_send_each' 'send', 'send_all', 'send_multicast', 'send_each', + 'send_each_async', 'send_each_for_multicast', 'subscribe_to_topic', 'unsubscribe_from_topic', -] # type: ignore +] -TFirebaseError = TypeVar('TFirebaseError', bound=exceptions.FirebaseError) AndroidConfig = _messaging_utils.AndroidConfig AndroidFCMOptions = _messaging_utils.AndroidFCMOptions @@ -104,10 +107,10 @@ UnregisteredError = _messaging_utils.UnregisteredError -def _get_messaging_service(app) -> _MessagingService: +def _get_messaging_service(app: Optional[App]) -> _MessagingService: return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService) -def send(message, dry_run=False, app=None): +def send(message, dry_run=False, app: Optional[App] = None): """Sends the given message via Firebase Cloud Messaging (FCM). If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -147,8 +150,8 @@ def send_each(messages, dry_run=False, app=None): """ return _get_messaging_service(app).send_each(messages, dry_run) -async def async_send_each(messages, dry_run=True, app: firebase_admin.App | None = None) -> BatchResponse: - return await _get_messaging_service(app).async_send_each(messages, dry_run) +async def send_each_async(messages, dry_run=True, app: Optional[App] = None) -> BatchResponse: + return await _get_messaging_service(app).send_each_async(messages, dry_run) def send_each_for_multicast(multicast_message, dry_run=False, app=None): """Sends the given mutlicast message to each token via Firebase Cloud Messaging (FCM). @@ -374,20 +377,26 @@ def exception(self): return self._exception # Auth Flow +# TODO: Remove comments # The aim here is to be able to get auth credentials right before the request is sent. # This is similar to what is done in transport.requests.AuthorizedSession(). # We can then pass this in at the client level. -class CustomGoogleAuth(httpx.Auth): - def __init__(self, credentials: credentials.Credentials): - self._credential = credentials + +# Notes: +# - This implementations does not cover timeouts on requests sent to refresh credentials. +# - Uses HTTP/1 and a blocking credential for refreshing. +class GoogleAuthCredentialFlow(httpx.Auth): + """Google Auth Credential Auth Flow""" + def __init__(self, credential: credentials.Credentials): + self._credential = credential self._max_refresh_attempts = 2 self._refresh_status_codes = (401,) - + def apply_auth_headers(self, request: httpx.Request): # Build request used to refresh credentials if needed - auth_request = transport.requests.Request() # type: ignore - # This refreshes the credentials if needed and mutates the request headers to contain access token - # and any other google auth headers + auth_request = auth_requests.Request() + # This refreshes the credentials if needed and mutates the request headers to + # contain access token and any other google auth headers self._credential.before_request(auth_request, request.method, request.url, request.headers) @@ -395,27 +404,26 @@ def auth_flow(self, request: httpx.Request): # Keep original headers since `credentials.before_request` mutates the passed headers and we # want to keep the original in cause we need an auth retry. _original_headers = request.headers.copy() - + _credential_refresh_attempt = 0 - while ( - _credential_refresh_attempt < self._max_refresh_attempts - ): + while _credential_refresh_attempt <= self._max_refresh_attempts: # copy original headers request.headers = _original_headers.copy() # mutates request headers self.apply_auth_headers(request) - + # Continue to perform the request # yield here dispatches the request and returns with the response response: httpx.Response = yield request - - # We can check the result of the response and determine in we need to retry on refreshable status codes. - # Current transport.requests.AuthorizedSession() only does this on 401 errors. We should do the same. + + # We can check the result of the response and determine in we need to retry + # on refreshable status codes. Current transport.requests.AuthorizedSession() + # only does this on 401 errors. We should do the same. if response.status_code in self._refresh_status_codes: _credential_refresh_attempt += 1 - print(response.status_code, response.reason_phrase, _credential_refresh_attempt) else: - break; + break + # Last yielded response is auto returned. @@ -453,7 +461,7 @@ def __init__(self, app) -> None: self._client = _http_client.JsonHttpClient(credential=self._credential, timeout=timeout) self._async_client = httpx.AsyncClient( http2=True, - auth=CustomGoogleAuth(self._credential), + auth=GoogleAuthCredentialFlow(self._credential), timeout=timeout, transport=HttpxRetryTransport() ) @@ -509,13 +517,13 @@ def send_data(data): message='Unknown error while making remote service calls: {0}'.format(error), cause=error) - async def async_send_each(self, messages: List[Message], dry_run: bool = True) -> BatchResponse: + async def send_each_async(self, messages: List[Message], dry_run: bool = True) -> BatchResponse: """Sends the given messages to FCM via the FCM v1 API.""" if not isinstance(messages, list): raise ValueError('messages must be a list of messaging.Message instances.') if len(messages) > 1000: raise ValueError('messages must not contain more than 500 elements.') - + async def send_data(data): try: resp = await self._async_client.request( @@ -661,7 +669,8 @@ def _handle_batch_error(self, error): """Handles errors received from the googleapiclient while making batch requests.""" return _gapic_utils.handle_platform_error_from_googleapiclient( error, _MessagingService._build_fcm_error_googleapiclient) - + + # TODO: Remove comments # We should be careful to clean up the httpx clients. # Since we are using an async client we must also close in async. However we can sync wrap this. # The close method is called by the app on shutdown/clean-up of each service. We don't seem to @@ -677,14 +686,16 @@ def _build_fcm_error_requests(cls, error, message, error_dict): return exc_type(message, cause=error, http_response=error.response) if exc_type else None @classmethod - def _build_fcm_error_httpx(cls, error: httpx.HTTPError, message, error_dict) -> Optional[exceptions.FirebaseError]: + def _build_fcm_error_httpx( + cls, error: httpx.HTTPError, message, error_dict + ) -> Optional[exceptions.FirebaseError]: """Parses a httpx error response from the FCM API and creates a FCM-specific exception if appropriate.""" exc_type = cls._build_fcm_error(error_dict) if isinstance(error, httpx.HTTPStatusError): - return exc_type(message, cause=error, http_response=error.response) if exc_type else None - else: - return exc_type(message, cause=error) if exc_type else None + return exc_type( + message, cause=error, http_response=error.response) if exc_type else None + return exc_type(message, cause=error) if exc_type else None @classmethod @@ -706,42 +717,43 @@ def _build_fcm_error(cls, error_dict) -> Optional[Callable[..., exceptions.Fireb return _MessagingService.FCM_ERROR_TYPES.get(fcm_code) if fcm_code else None +# TODO: Remove comments +# Notes: +# This implementation currently only covers basic retires for pre-defined status errors class HttpxRetryTransport(httpx.AsyncBaseTransport): + """HTTPX transport with retry logic.""" # We could also support passing kwargs here - def __init__(self) -> None: + def __init__(self, **kwargs) -> None: + # Hardcoded settings for now self._retryable_status_codes = (500, 503,) self._max_retry_count = 4 - # We should use a full AsyncHTTPTransport under the hood since that is - # fully implemented. We could consider making this class extend a - # AsyncHTTPTransport instead and use the parent class's methods to handle - # requests. We sould also ensure that that transport's internal retry is - # not enabled. - self._wrapped_transport = httpx.AsyncHTTPTransport(retries=0, http2=True) - - # Checklist: - # - Do we want to disable built in retries - # - Can we dispatch the same request multiple times? Is there any side effects? - - # Two types of retries - # - Status code (500s, redirect) - # - Error code (read, connect, other) - # - more ??? - + # - We use a full AsyncHTTPTransport under the hood to make use of it's + # fully implemented `handle_async_request()`. + # - We could consider making the `HttpxRetryTransport`` class extend a + # `AsyncHTTPTransport` instead and use the parent class's methods to handle + # requests. + # - We should also ensure that that transport's internal retry is + # not enabled. + transport_kwargs = kwargs.copy() + transport_kwargs.update({'retries': 0, 'http2': True}) + self._wrapped_transport = httpx.AsyncHTTPTransport(**transport_kwargs) + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: _retry_count = 0 - + while True: # Dispatch request + # Let exceptions pass through for now response = await self._wrapped_transport.handle_async_request(request) - + # Check if request is retryable if response.status_code in self._retryable_status_codes: _retry_count += 1 - - # Figure out how we want to handle 0 here + + # Return if retries exhausted if _retry_count > self._max_retry_count: return response else: return response - # break; \ No newline at end of file diff --git a/integration/conftest.py b/integration/conftest.py index ee9509017..bdecca40e 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -82,9 +82,8 @@ def api_key(request): # yield loop # loop.close() -# def pytest_collection_modifyitems(items): pytest_asyncio_tests = (item for item in items if is_async_test(item)) session_scope_marker = pytest.mark.asyncio(loop_scope="session") for async_test in pytest_asyncio_tests: - async_test.add_marker(session_scope_marker, append=False) \ No newline at end of file + async_test.add_marker(session_scope_marker, append=False) diff --git a/integration/test_messaging.py b/integration/test_messaging.py index dfd2a8684..af35ce01b 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -14,7 +14,6 @@ """Integration tests for firebase_admin.messaging module.""" -import asyncio import re from datetime import datetime @@ -224,7 +223,7 @@ def test_unsubscribe(): assert resp.success_count + resp.failure_count == 1 @pytest.mark.asyncio -async def test_async_send_each(): +async def test_send_each_async(): messages = [ messaging.Message( topic='foo-bar', notification=messaging.Notification('Title', 'Body')), @@ -234,7 +233,7 @@ async def test_async_send_each(): token='not-a-token', notification=messaging.Notification('Title', 'Body')), ] - batch_response = await messaging.async_send_each(messages, dry_run=True) + batch_response = await messaging.send_each_async(messages, dry_run=True) assert batch_response.success_count == 2 assert batch_response.failure_count == 1 @@ -257,7 +256,7 @@ async def test_async_send_each(): # @pytest.mark.asyncio -# async def test_async_send_each_error(): +# async def test_send_each_async_error(): # messages = [ # messaging.Message( # topic='foo-bar', notification=messaging.Notification('Title', 'Body')), @@ -267,7 +266,7 @@ async def test_async_send_each(): # token='not-a-token', notification=messaging.Notification('Title', 'Body')), # ] -# batch_response = await messaging.async_send_each(messages, dry_run=True) +# batch_response = await messaging.send_each_async(messages, dry_run=True) # assert batch_response.success_count == 2 # assert batch_response.failure_count == 1 @@ -289,13 +288,13 @@ async def test_async_send_each(): # assert response.message_id is None @pytest.mark.asyncio -async def test_async_send_each_500(): +async def test_send_each_async_500(): messages = [] for msg_number in range(500): topic = 'foo-bar-{0}'.format(msg_number % 10) messages.append(messaging.Message(topic=topic)) - batch_response = await messaging.async_send_each(messages, dry_run=True) + batch_response = await messaging.send_each_async(messages, dry_run=True) assert batch_response.success_count == 500 assert batch_response.failure_count == 0 @@ -304,4 +303,3 @@ async def test_async_send_each_500(): assert response.success is True assert response.exception is None assert re.match('^projects/.*/messages/.*$', response.message_id) - diff --git a/tests/test_messaging.py b/tests/test_messaging.py index f4ce943e4..aee3ae4e6 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -22,7 +22,6 @@ from googleapiclient import http from googleapiclient import _helpers import pytest -import httpx import firebase_admin from firebase_admin import exceptions @@ -1929,7 +1928,7 @@ def test_send_each(self): @respx.mock @pytest.mark.asyncio - async def test_async_send_each(self): + async def test_send_each_async(self): responses = [ respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'}), respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id2'}), @@ -1938,20 +1937,24 @@ async def test_async_send_each(self): msg1 = messaging.Message(topic='foo1') msg2 = messaging.Message(topic='foo2') msg3 = messaging.Message(topic='foo3') - route = respx.request('POST', 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send').mock(side_effect=responses) - - batch_response = await messaging.async_send_each([msg1, msg2, msg3], dry_run=True) - + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + + batch_response = await messaging.send_each_async([msg1, msg2, msg3], dry_run=True) + # try: - # batch_response = await messaging.async_send_each([msg1, msg2], dry_run=True) + # batch_response = await messaging.send_each_async([msg1, msg2], dry_run=True) # except Exception as error: # if isinstance(error.cause.__cause__, StopIteration): # raise Exception('Received more requests than mocks') - + assert batch_response.success_count == 3 assert batch_response.failure_count == 0 assert len(batch_response.responses) == 3 - assert [r.message_id for r in batch_response.responses] == ['message-id1', 'message-id2', 'message-id3'] + assert [r.message_id for r in batch_response.responses] \ + == ['message-id1', 'message-id2', 'message-id3'] assert all([r.success for r in batch_response.responses]) assert not any([r.exception for r in batch_response.responses]) @@ -1959,7 +1962,7 @@ async def test_async_send_each(self): @respx.mock @pytest.mark.asyncio - async def test_async_send_each_error_401_fail_auth_retry(self): + async def test_send_each_async_error_401_fail_auth_retry(self): payload = json.dumps({ 'error': { 'status': 'UNAUTHENTICATED', @@ -1972,23 +1975,26 @@ async def test_async_send_each_error_401_fail_auth_retry(self): ], } }) - + responses = repeat(respx.MockResponse(401, http_version='HTTP/2', content=payload)) - + msg1 = messaging.Message(topic='foo1') - route = respx.request('POST', 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send').mock(side_effect=responses) - batch_response = await messaging.async_send_each([msg1], dry_run=True) - - assert route.call_count == 2 + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + batch_response = await messaging.send_each_async([msg1], dry_run=True) + + assert route.call_count == 3 assert batch_response.success_count == 0 assert batch_response.failure_count == 1 assert len(batch_response.responses) == 1 exception = batch_response.responses[0].exception assert isinstance(exception, exceptions.UnauthenticatedError) - + @respx.mock @pytest.mark.asyncio - async def test_async_send_each_error_401_pass_on_auth_retry(self): + async def test_send_each_async_error_401_pass_on_auth_retry(self): payload = json.dumps({ 'error': { 'status': 'UNAUTHENTICATED', @@ -2005,11 +2011,14 @@ async def test_async_send_each_error_401_pass_on_auth_retry(self): respx.MockResponse(401, http_version='HTTP/2', content=payload), respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'}), ] - + msg1 = messaging.Message(topic='foo1') - route = respx.request('POST', 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send').mock(side_effect=responses) - batch_response = await messaging.async_send_each([msg1], dry_run=True) - + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + batch_response = await messaging.send_each_async([msg1], dry_run=True) + assert route.call_count == 2 assert batch_response.success_count == 1 assert batch_response.failure_count == 0 @@ -2020,7 +2029,7 @@ async def test_async_send_each_error_401_pass_on_auth_retry(self): @respx.mock @pytest.mark.asyncio - async def test_async_send_each_error_500_fail_retry_config(self): + async def test_send_each_async_error_500_fail_retry_config(self): payload = json.dumps({ 'error': { 'status': 'INTERNAL', @@ -2033,13 +2042,16 @@ async def test_async_send_each_error_500_fail_retry_config(self): ], } }) - + responses = repeat(respx.MockResponse(500, http_version='HTTP/2', content=payload)) - + msg1 = messaging.Message(topic='foo1') - route = respx.request('POST', 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send').mock(side_effect=responses) - batch_response = await messaging.async_send_each([msg1], dry_run=True) - + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + batch_response = await messaging.send_each_async([msg1], dry_run=True) + assert route.call_count == 5 assert batch_response.success_count == 0 assert batch_response.failure_count == 1 @@ -2050,7 +2062,7 @@ async def test_async_send_each_error_500_fail_retry_config(self): @respx.mock @pytest.mark.asyncio - async def test_async_send_each_error_500_pass_on_retry_config(self): + async def test_send_each_async_error_500_pass_on_retry_config(self): payload = json.dumps({ 'error': { 'status': 'INTERNAL', @@ -2072,11 +2084,14 @@ async def test_async_send_each_error_500_pass_on_retry_config(self): respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'}), ], ) - + msg1 = messaging.Message(topic='foo1') - route = respx.request('POST', 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send').mock(side_effect=responses) - batch_response = await messaging.async_send_each([msg1], dry_run=True) - + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + batch_response = await messaging.send_each_async([msg1], dry_run=True) + assert route.call_count == 5 assert batch_response.success_count == 1 assert batch_response.failure_count == 0 @@ -2087,7 +2102,7 @@ async def test_async_send_each_error_500_pass_on_retry_config(self): # @respx.mock # @pytest.mark.asyncio - # async def test_async_send_each_error_request(self): + # async def test_send_each_async_error_request(self): # payload = json.dumps({ # 'error': { # 'status': 'INTERNAL', @@ -2105,23 +2120,27 @@ async def test_async_send_each_error_500_pass_on_retry_config(self): # httpx.ConnectError("Test request error", request=httpx.Request('POST', 'URL')) # # respx.MockResponse(500, http_version='HTTP/2', content=payload), # ], - # # repeat(respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'})), + # # repeat( + # respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'})), # # respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id3'}), # ) - + # # responses = repeat(respx.MockResponse(500, http_version='HTTP/2', content=payload)) - + # msg1 = messaging.Message(topic='foo1') - # route = respx.request('POST', 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send').mock(side_effect=responses) - # batch_response = await messaging.async_send_each([msg1], dry_run=True) - + # route = respx.request( + # 'POST', + # 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + # ).mock(side_effect=responses) + # batch_response = await messaging.send_each_async([msg1], dry_run=True) + # assert route.call_count == 1 # assert batch_response.success_count == 0 # assert batch_response.failure_count == 1 # assert len(batch_response.responses) == 1 # exception = batch_response.responses[0].exception # assert isinstance(exception, exceptions.UnavailableError) - + # # assert route.call_count == 4 # # assert batch_response.success_count == 1 # # assert batch_response.failure_count == 0 From a8aa968e6eb3b7b2ce6b1b52288c026452b6b480 Mon Sep 17 00:00:00 2001 From: Jonathan Edey Date: Tue, 8 Apr 2025 10:45:17 -0400 Subject: [PATCH 3/5] fix: Add extra dependancy for http2 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c662dd53a..ba6f2f947 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,4 @@ google-api-python-client >= 1.7.8 google-cloud-firestore >= 2.19.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.37.1 pyjwt[crypto] >= 2.5.0 -httpx == 0.28.1 \ No newline at end of file +httpx[http2] == 0.28.1 \ No newline at end of file From 3343e2529884dc18a1da731bea50f570bc32c6e0 Mon Sep 17 00:00:00 2001 From: Jonathan Edey Date: Tue, 8 Apr 2025 10:54:03 -0400 Subject: [PATCH 4/5] fix: reset message batch limit to 500 --- firebase_admin/messaging.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 4a2b16642..abac5ae54 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -492,7 +492,7 @@ def send_each(self, messages, dry_run=False): """Sends the given messages to FCM via the FCM v1 API.""" if not isinstance(messages, list): raise ValueError('messages must be a list of messaging.Message instances.') - if len(messages) > 1000: + if len(messages) > 500: raise ValueError('messages must not contain more than 500 elements.') def send_data(data): @@ -521,7 +521,7 @@ async def send_each_async(self, messages: List[Message], dry_run: bool = True) - """Sends the given messages to FCM via the FCM v1 API.""" if not isinstance(messages, list): raise ValueError('messages must be a list of messaging.Message instances.') - if len(messages) > 1000: + if len(messages) > 500: raise ValueError('messages must not contain more than 500 elements.') async def send_data(data): From 7a5f245b0eeea3390b7966048ac94283bf246edb Mon Sep 17 00:00:00 2001 From: Jonathan Edey Date: Thu, 10 Apr 2025 11:11:37 -0400 Subject: [PATCH 5/5] fix: Add new import to `setup.py` --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 23be6d481..e92d207aa 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,7 @@ 'google-cloud-firestore>=2.19.0; platform.python_implementation != "PyPy"', 'google-cloud-storage>=1.37.1', 'pyjwt[crypto] >= 2.5.0', + 'httpx[http2] == 0.28.1', ] setup(