From f69e14c38c6367527ff293ec40617be95d09015d Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 12 Jun 2019 14:32:08 -0700 Subject: [PATCH 01/37] Introduced the exceptions module (#296) * Added the exceptions module * Cleaned up the error handling logic; Added tests * Updated docs; Fixed some typos --- CHANGELOG.md | 5 +- firebase_admin/_utils.py | 48 +++++++++ firebase_admin/exceptions.py | 182 ++++++++++++++++++++++++++++++++ firebase_admin/instance_id.py | 15 +-- integration/test_instance_id.py | 3 +- tests/test_exceptions.py | 96 +++++++++++++++++ tests/test_instance_id.py | 51 +++++++-- 7 files changed, 376 insertions(+), 24 deletions(-) create mode 100644 firebase_admin/exceptions.py create mode 100644 tests/test_exceptions.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fc1a759e..86acfcdc7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,9 @@ # Unreleased -- +- [added] Added the new `firebase_admin.exceptions` module containing the + base exception types and global error codes. +- [changed] Updated the `firebase_admin.instance_id` module to use the new + shared exception types. The type `instance_id.ApiCallError` was removed. # v2.17.0 diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index b28853868..61fe8eb1c 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -14,7 +14,22 @@ """Internal utilities common to all modules.""" +import requests + import firebase_admin +from firebase_admin import exceptions + + +_STATUS_TO_EXCEPTION_TYPE = { + 400: exceptions.InvalidArgumentError, + 401: exceptions.UnauthenticatedError, + 403: exceptions.PermissionDeniedError, + 404: exceptions.NotFoundError, + 409: exceptions.ConflictError, + 429: exceptions.ResourceExhaustedError, + 500: exceptions.InternalError, + 503: exceptions.UnavailableError, +} def _get_initialized_app(app): @@ -33,3 +48,36 @@ def _get_initialized_app(app): def get_app_service(app, name, initializer): app = _get_initialized_app(app) return app._get_service(name, initializer) # pylint: disable=protected-access + +def handle_requests_error(error, message=None, status=None): + """Constructs a ``FirebaseError`` from the given requests error. + + Args: + error: An error raised by the reqests module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError`` (optional). If not + specified the string representation of the ``error`` argument is used as the message. + status: An HTTP status code that will be used to determine the resulting error type + (optional). If not specified the HTTP status code on the error response is used. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if isinstance(error, requests.exceptions.Timeout): + return exceptions.DeadlineExceededError( + message='Timed out while making an API call: {0}'.format(error), + cause=error) + elif isinstance(error, requests.exceptions.ConnectionError): + return exceptions.UnavailableError( + message='Failed to establish a connection: {0}'.format(error), + cause=error) + elif error.response is None: + return exceptions.UnknownError( + message='Unknown error while making a remote service call: {0}'.format(error), + cause=error) + + if not status: + status = error.response.status_code + if not message: + message = str(error) + err_type = _STATUS_TO_EXCEPTION_TYPE.get(status, exceptions.UnknownError) + return err_type(message=message, cause=error, http_response=error.response) diff --git a/firebase_admin/exceptions.py b/firebase_admin/exceptions.py new file mode 100644 index 000000000..f1297dbb3 --- /dev/null +++ b/firebase_admin/exceptions.py @@ -0,0 +1,182 @@ +# Copyright 2019 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase Exceptions module. + +This module defines the base types for exceptions and the platform-wide error codes as outlined in +https://cloud.google.com/apis/design/errors. +""" + + +INVALID_ARGUMENT = 'INVALID_ARGUMENT' +FAILED_PRECONDITION = 'FAILED_PRECONDITION' +OUT_OF_RANGE = 'OUT_OF_RANGE' +UNAUTHENTICATED = 'UNAUTHENTICATED' +PERMISSION_DENIED = 'PERMISSION_DENIED' +NOT_FOUND = 'NOT_FOUND' +CONFLICT = 'CONFLICT' +ABORTED = 'ABORTED' +ALREADY_EXISTS = 'ALREADY_EXISTS' +RESOURCE_EXHAUSTED = 'RESOURCE_EXHAUSTED' +CANCELLED = 'CANCELLED' +DATA_LOSS = 'DATA_LOSS' +UNKNOWN = 'UNKNOWN' +INTERNAL = 'INTERNAL' +UNAVAILABLE = 'UNAVAILABLE' +DEADLINE_EXCEEDED = 'DEADLINE_EXCEEDED' + + +class FirebaseError(Exception): + """Base class for all errors raised by the Admin SDK.""" + + def __init__(self, code, message, cause=None, http_response=None): + Exception.__init__(self, message) + self._code = code + self._cause = cause + self._http_response = http_response + + @property + def code(self): + return self._code + + @property + def cause(self): + return self._cause + + @property + def http_response(self): + return self._http_response + + +class InvalidArgumentError(FirebaseError): + """Client specified an invalid argument.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, INVALID_ARGUMENT, message, cause, http_response) + + +class FailedPreconditionError(FirebaseError): + """Request can not be executed in the current system state, such as deleting a non-empty + directory.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, FAILED_PRECONDITION, message, cause, http_response) + + +class OutOfRangeError(FirebaseError): + """Client specified an invalid range.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, OUT_OF_RANGE, message, cause, http_response) + + +class UnauthenticatedError(FirebaseError): + """Request not authenticated due to missing, invalid, or expired OAuth token.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, UNAUTHENTICATED, message, cause, http_response) + + +class PermissionDeniedError(FirebaseError): + """Client does not have sufficient permission. + + This can happen because the OAuth token does not have the right scopes, the client doesn't + have permission, or the API has not been enabled for the client project. + """ + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, PERMISSION_DENIED, message, cause, http_response) + + +class NotFoundError(FirebaseError): + """A specified resource is not found, or the request is rejected by undisclosed reasons, such + as whitelisting.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, NOT_FOUND, message, cause, http_response) + + +class ConflictError(FirebaseError): + """Concurrency conflict, such as read-modify-write conflict.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, CONFLICT, message, cause, http_response) + + +class AbortedError(FirebaseError): + """Concurrency conflict, such as read-modify-write conflict.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, ABORTED, message, cause, http_response) + + +class AlreadyExistsError(FirebaseError): + """The resource that a client tried to create already exists.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, ALREADY_EXISTS, message, cause, http_response) + + +class ResourceExhaustedError(FirebaseError): + """Either out of resource quota or reaching rate limiting.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, RESOURCE_EXHAUSTED, message, cause, http_response) + + +class CancelledError(FirebaseError): + """Request cancelled by the client.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, CANCELLED, message, cause, http_response) + + +class DataLossError(FirebaseError): + """Unrecoverable data loss or data corruption.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, DATA_LOSS, message, cause, http_response) + + +class UnknownError(FirebaseError): + """Unknown server error.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, UNKNOWN, message, cause, http_response) + + +class InternalError(FirebaseError): + """Internal server error.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, INTERNAL, message, cause, http_response) + + +class UnavailableError(FirebaseError): + """Service unavailable. Typically the server is down.""" + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, UNAVAILABLE, message, cause, http_response) + + +class DeadlineExceededError(FirebaseError): + """Request deadline exceeded. + + This will happen only if the caller sets a deadline that is shorter than the method's + default deadline (i.e. requested deadline is not enough for the server to process the + request) and the request did not finish within the deadline. + """ + + def __init__(self, message, cause=None, http_response=None): + FirebaseError.__init__(self, DEADLINE_EXCEEDED, message, cause, http_response) diff --git a/firebase_admin/instance_id.py b/firebase_admin/instance_id.py index b290e9e7f..e9134fc28 100644 --- a/firebase_admin/instance_id.py +++ b/firebase_admin/instance_id.py @@ -53,14 +53,6 @@ def delete_instance_id(instance_id, app=None): _get_iid_service(app).delete_instance_id(instance_id) -class ApiCallError(Exception): - """Represents an Exception encountered while invoking the Firebase instance ID service.""" - - def __init__(self, message, error): - Exception.__init__(self, message) - self.detail = error - - class _InstanceIdService(object): """Provides methods for interacting with the remote instance ID service.""" @@ -94,14 +86,15 @@ def delete_instance_id(self, instance_id): try: self._client.request('delete', path) except requests.exceptions.RequestException as error: - raise ApiCallError(self._extract_message(instance_id, error), error) + msg = self._extract_message(instance_id, error) + raise _utils.handle_requests_error(error, msg) def _extract_message(self, instance_id, error): if error.response is None: - return str(error) + return None status = error.response.status_code msg = self.error_codes.get(status) if msg: return 'Instance ID "{0}": {1}'.format(instance_id, msg) else: - return str(error) + return 'Instance ID "{0}": {1}'.format(instance_id, error) diff --git a/integration/test_instance_id.py b/integration/test_instance_id.py index 1a176a9a0..99b6787d3 100644 --- a/integration/test_instance_id.py +++ b/integration/test_instance_id.py @@ -16,10 +16,11 @@ import pytest +from firebase_admin import exceptions from firebase_admin import instance_id def test_delete_non_existing(): - with pytest.raises(instance_id.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: # legal instance IDs are /[cdef][A-Za-z0-9_-]{9}[AEIMQUYcgkosw048]/ instance_id.delete_instance_id('fictive-ID0') assert str(excinfo.value) == 'Instance ID "fictive-ID0": Failed to find the instance ID.' diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 000000000..f2897ab3c --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,96 @@ +# Copyright 2019 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import requests +from requests import models + +from firebase_admin import exceptions +from firebase_admin import _utils + + +def test_timeout_error(): + error = requests.exceptions.Timeout('Test error') + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.DeadlineExceededError) + assert str(firebase_error) == 'Timed out while making an API call: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + +def test_connection_error(): + error = requests.exceptions.ConnectionError('Test error') + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == 'Failed to establish a connection: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + +def test_unknown_transport_error(): + error = requests.exceptions.RequestException('Test error') + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.UnknownError) + assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + +def test_http_response(): + resp = models.Response() + resp.status_code = 500 + error = requests.exceptions.RequestException('Test error', response=resp) + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.InternalError) + assert str(firebase_error) == 'Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + +def test_http_response_with_unknown_status(): + resp = models.Response() + resp.status_code = 501 + error = requests.exceptions.RequestException('Test error', response=resp) + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.UnknownError) + assert str(firebase_error) == 'Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + +def test_http_response_with_message(): + resp = models.Response() + resp.status_code = 500 + error = requests.exceptions.RequestException('Test error', response=resp) + firebase_error = _utils.handle_requests_error(error, message='Explicit error message') + assert isinstance(firebase_error, exceptions.InternalError) + assert str(firebase_error) == 'Explicit error message' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + +def test_http_response_with_status(): + resp = models.Response() + resp.status_code = 500 + error = requests.exceptions.RequestException('Test error', response=resp) + firebase_error = _utils.handle_requests_error(error, status=503) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == 'Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + +def test_http_response_with_message_and_status(): + resp = models.Response() + resp.status_code = 500 + error = requests.exceptions.RequestException('Test error', response=resp) + firebase_error = _utils.handle_requests_error( + error, message='Explicit error message', status=503) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == 'Explicit error message' + assert firebase_error.cause is error + assert firebase_error.http_response is resp diff --git a/tests/test_instance_id.py b/tests/test_instance_id.py index e8e8edd27..83e66491a 100644 --- a/tests/test_instance_id.py +++ b/tests/test_instance_id.py @@ -17,15 +17,37 @@ import pytest import firebase_admin +from firebase_admin import exceptions from firebase_admin import instance_id from tests import testutils http_errors = { - 404: 'Instance ID "test_iid": Failed to find the instance ID.', - 409: 'Instance ID "test_iid": Already deleted.', - 429: 'Instance ID "test_iid": Request throttled out by the backend server.', - 500: 'Instance ID "test_iid": Internal server error.', + 400: ( + 'Instance ID "test_iid": Malformed instance ID argument.', + exceptions.InvalidArgumentError), + 401: ( + 'Instance ID "test_iid": Request not authorized.', + exceptions.UnauthenticatedError), + 403: ( + ('Instance ID "test_iid": Project does not match instance ID or the client does not have ' + 'sufficient privileges.'), + exceptions.PermissionDeniedError), + 404: ( + 'Instance ID "test_iid": Failed to find the instance ID.', + exceptions.NotFoundError), + 409: ( + 'Instance ID "test_iid": Already deleted.', + exceptions.ConflictError), + 429: ( + 'Instance ID "test_iid": Request throttled out by the backend server.', + exceptions.ResourceExhaustedError), + 500: ( + 'Instance ID "test_iid": Internal server error.', + exceptions.InternalError), + 503: ( + 'Instance ID "test_iid": Backend servers are over capacity. Try again later.', + exceptions.UnavailableError), } class TestDeleteInstanceId(object): @@ -74,11 +96,17 @@ def test_delete_instance_id_error(self, status): cred = testutils.MockCredential() app = firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) _, recorder = self._instrument_iid_service(app, status, 'some error') - with pytest.raises(instance_id.ApiCallError) as excinfo: + msg, exc = http_errors.get(status) + with pytest.raises(exc) as excinfo: instance_id.delete_instance_id('test_iid') - assert str(excinfo.value) == http_errors.get(status) - assert excinfo.value.detail is not None - assert len(recorder) == 1 + assert str(excinfo.value) == msg + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None + if status != 401: + assert len(recorder) == 1 + else: + # 401 responses are automatically retried by google-auth + assert len(recorder) == 3 assert recorder[0].method == 'DELETE' assert recorder[0].url == self._get_url('explicit-project-id', 'test_iid') @@ -86,12 +114,13 @@ def test_delete_instance_id_unexpected_error(self): cred = testutils.MockCredential() app = firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) _, recorder = self._instrument_iid_service(app, 501, 'some error') - with pytest.raises(instance_id.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnknownError) as excinfo: instance_id.delete_instance_id('test_iid') url = self._get_url('explicit-project-id', 'test_iid') - message = '501 Server Error: None for url: {0}'.format(url) + message = 'Instance ID "test_iid": 501 Server Error: None for url: {0}'.format(url) assert str(excinfo.value) == message - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == url From 2879a22a0f1eb5ec7b72b900ac875a9d1b1238a3 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 20 Jun 2019 16:13:51 -0700 Subject: [PATCH 02/37] Migrating FCM Send APIs to the New Exceptions (#297) * Migrated FCM send APIs to the new error handling regime * Moved error parsing logic to _utils * Refactored OP error handling code * Fixing a broken test * Added utils for handling googleapiclient errors * Added tests for new error handling logic * Updated public API docs * Fixing test for python3 * Cleaning up the error code lookup code * Cleaning up the error parsing APIs * Cleaned up error parsing logic; Updated docs --- firebase_admin/_messaging_utils.py | 32 +++ firebase_admin/_utils.py | 242 ++++++++++++++++-- firebase_admin/messaging.py | 105 +++----- integration/test_messaging.py | 21 +- tests/test_exceptions.py | 387 +++++++++++++++++++++++------ tests/test_messaging.py | 187 ++++++++------ 6 files changed, 742 insertions(+), 232 deletions(-) diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 17067f175..127221367 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -22,6 +22,8 @@ import six +from firebase_admin import exceptions + class Message(object): """A message that can be sent via Firebase Cloud Messaging. @@ -797,3 +799,33 @@ def default(self, obj): # pylint: disable=method-hidden if target_count != 1: raise ValueError('Exactly one of token, topic or condition must be specified.') return result + + +class ThirdPartyAuthError(exceptions.UnauthenticatedError): + """APNs certificate or web push auth key was invalid or missing.""" + + def __init__(self, message, cause=None, http_response=None): + exceptions.UnauthenticatedError.__init__(self, message, cause, http_response) + + +class QuotaExceededError(exceptions.ResourceExhaustedError): + """Sending limit exceeded for the message target.""" + + def __init__(self, message, cause=None, http_response=None): + exceptions.ResourceExhaustedError.__init__(self, message, cause, http_response) + + +class SenderIdMismatchError(exceptions.PermissionDeniedError): + """The authenticated sender ID is different from the sender ID for the registration token.""" + + def __init__(self, message, cause=None, http_response=None): + exceptions.PermissionDeniedError.__init__(self, message, cause, http_response) + + +class UnregisteredError(exceptions.NotFoundError): + """App instance was unregistered from FCM. + + This usually means that the token used is no longer valid and a new one must be used.""" + + def __init__(self, message, cause=None, http_response=None): + exceptions.NotFoundError.__init__(self, message, cause, http_response) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 61fe8eb1c..42b83809e 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -14,21 +14,47 @@ """Internal utilities common to all modules.""" +import json +import socket + +import googleapiclient +import httplib2 import requests +import six import firebase_admin from firebase_admin import exceptions -_STATUS_TO_EXCEPTION_TYPE = { - 400: exceptions.InvalidArgumentError, - 401: exceptions.UnauthenticatedError, - 403: exceptions.PermissionDeniedError, - 404: exceptions.NotFoundError, - 409: exceptions.ConflictError, - 429: exceptions.ResourceExhaustedError, - 500: exceptions.InternalError, - 503: exceptions.UnavailableError, +_ERROR_CODE_TO_EXCEPTION_TYPE = { + exceptions.INVALID_ARGUMENT: exceptions.InvalidArgumentError, + exceptions.FAILED_PRECONDITION: exceptions.FailedPreconditionError, + exceptions.OUT_OF_RANGE: exceptions.OutOfRangeError, + exceptions.UNAUTHENTICATED: exceptions.UnauthenticatedError, + exceptions.PERMISSION_DENIED: exceptions.PermissionDeniedError, + exceptions.NOT_FOUND: exceptions.NotFoundError, + exceptions.ABORTED: exceptions.AbortedError, + exceptions.ALREADY_EXISTS: exceptions.AlreadyExistsError, + exceptions.CONFLICT: exceptions.ConflictError, + exceptions.RESOURCE_EXHAUSTED: exceptions.ResourceExhaustedError, + exceptions.CANCELLED: exceptions.CancelledError, + exceptions.DATA_LOSS: exceptions.DataLossError, + exceptions.UNKNOWN: exceptions.UnknownError, + exceptions.INTERNAL: exceptions.InternalError, + exceptions.UNAVAILABLE: exceptions.UnavailableError, + exceptions.DEADLINE_EXCEEDED: exceptions.DeadlineExceededError, +} + + +_HTTP_STATUS_TO_ERROR_CODE = { + 400: exceptions.INVALID_ARGUMENT, + 401: exceptions.UNAUTHENTICATED, + 403: exceptions.PERMISSION_DENIED, + 404: exceptions.NOT_FOUND, + 409: exceptions.CONFLICT, + 429: exceptions.RESOURCE_EXHAUSTED, + 500: exceptions.INTERNAL, + 503: exceptions.UNAVAILABLE, } @@ -45,19 +71,69 @@ def _get_initialized_app(app): raise ValueError('Illegal app argument. Argument must be of type ' ' firebase_admin.App, but given "{0}".'.format(type(app))) + def get_app_service(app, name, initializer): app = _get_initialized_app(app) return app._get_service(name, initializer) # pylint: disable=protected-access -def handle_requests_error(error, message=None, status=None): + +def handle_platform_error_from_requests(error, handle_func=None): + """Constructs a ``FirebaseError`` from the given requests error. + + This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. + + Args: + error: An error raised by the requests module while making an HTTP call to a GCP API. + handle_func: A function that can be used to handle platform errors in a custom way. When + specified, this function will be called with three arguments. It has the same + signature as ```_handle_func_requests``, but may return ``None``. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if error.response is None: + return handle_requests_error(error) + + response = error.response + content = response.content.decode() + status_code = response.status_code + error_dict, message = _parse_platform_error(content, status_code) + exc = None + if handle_func: + exc = handle_func(error, message, error_dict) + + return exc if exc else _handle_func_requests(error, message, error_dict) + + +def _handle_func_requests(error, message, error_dict): + """Constructs a ``FirebaseError`` from the given GCP error. + + Args: + error: An error raised by the requests module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError``. + error_dict: Parsed GCP error response. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code or None. + """ + code = error_dict.get('status') + return handle_requests_error(error, message, code) + + +def handle_requests_error(error, message=None, code=None): """Constructs a ``FirebaseError`` from the given requests error. + This method is agnostic of the remote service that produced the error, whether it is a GCP + service or otherwise. Therefore, this method does not attempt to parse the error response in + any way. + Args: - error: An error raised by the reqests module while making an HTTP call. + error: An error raised by the requests module while making an HTTP call. message: A message to be included in the resulting ``FirebaseError`` (optional). If not specified the string representation of the ``error`` argument is used as the message. - status: An HTTP status code that will be used to determine the resulting error type - (optional). If not specified the HTTP status code on the error response is used. + code: A GCP error code that will be used to determine the resulting error type (optional). + If not specified the HTTP status code on the error response is used to determine a + suitable error code. Returns: FirebaseError: A ``FirebaseError`` that can be raised to the user code. @@ -75,9 +151,143 @@ def handle_requests_error(error, message=None, status=None): message='Unknown error while making a remote service call: {0}'.format(error), cause=error) - if not status: - status = error.response.status_code + if not code: + code = _http_status_to_error_code(error.response.status_code) if not message: message = str(error) - err_type = _STATUS_TO_EXCEPTION_TYPE.get(status, exceptions.UnknownError) + + err_type = _error_code_to_exception_type(code) return err_type(message=message, cause=error, http_response=error.response) + + +def handle_platform_error_from_googleapiclient(error, handle_func=None): + """Constructs a ``FirebaseError`` from the given googleapiclient error. + + This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. + + Args: + error: An error raised by the googleapiclient while making an HTTP call to a GCP API. + handle_func: A function that can be used to handle platform errors in a custom way. When + specified, this function will be called with three arguments. It has the same + signature as ```_handle_func_googleapiclient``, but may return ``None``. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if not isinstance(error, googleapiclient.errors.HttpError): + return handle_googleapiclient_error(error) + + content = error.content.decode() + status_code = error.resp.status + error_dict, message = _parse_platform_error(content, status_code) + http_response = _http_response_from_googleapiclient_error(error) + exc = None + if handle_func: + exc = handle_func(error, message, error_dict, http_response) + + return exc if exc else _handle_func_googleapiclient(error, message, error_dict, http_response) + + +def _handle_func_googleapiclient(error, message, error_dict, http_response): + """Constructs a ``FirebaseError`` from the given GCP error. + + Args: + error: An error raised by the googleapiclient module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError``. + error_dict: Parsed GCP error response. + http_response: A requests HTTP response object to associate with the exception. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code or None. + """ + code = error_dict.get('status') + return handle_googleapiclient_error(error, message, code, http_response) + + +def handle_googleapiclient_error(error, message=None, code=None, http_response=None): + """Constructs a ``FirebaseError`` from the given googleapiclient error. + + This method is agnostic of the remote service that produced the error, whether it is a GCP + service or otherwise. Therefore, this method does not attempt to parse the error response in + any way. + + Args: + error: An error raised by the googleapiclient module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError`` (optional). If not + specified the string representation of the ``error`` argument is used as the message. + code: A GCP error code that will be used to determine the resulting error type (optional). + If not specified the HTTP status code on the error response is used to determine a + suitable error code. + http_response: A requests HTTP response object to associate with the exception (optional). + If not specified, one will be created from the ``error``. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if isinstance(error, socket.timeout) or ( + isinstance(error, socket.error) and 'timed out' in str(error)): + return exceptions.DeadlineExceededError( + message='Timed out while making an API call: {0}'.format(error), + cause=error) + elif isinstance(error, httplib2.ServerNotFoundError): + return exceptions.UnavailableError( + message='Failed to establish a connection: {0}'.format(error), + cause=error) + elif not isinstance(error, googleapiclient.errors.HttpError): + return exceptions.UnknownError( + message='Unknown error while making a remote service call: {0}'.format(error), + cause=error) + + if not code: + code = _http_status_to_error_code(error.resp.status) + if not message: + message = str(error) + if not http_response: + http_response = _http_response_from_googleapiclient_error(error) + + err_type = _error_code_to_exception_type(code) + return err_type(message=message, cause=error, http_response=http_response) + + +def _http_response_from_googleapiclient_error(error): + """Creates a requests HTTP Response object from the given googleapiclient error.""" + resp = requests.models.Response() + resp.raw = six.BytesIO(error.content) + resp.status_code = error.resp.status + return resp + + +def _http_status_to_error_code(status): + """Maps an HTTP status to a platform error code.""" + return _HTTP_STATUS_TO_ERROR_CODE.get(status, exceptions.UNKNOWN) + + +def _error_code_to_exception_type(code): + """Maps a platform error code to an exception type.""" + return _ERROR_CODE_TO_EXCEPTION_TYPE.get(code, exceptions.UnknownError) + + +def _parse_platform_error(content, status_code): + """Parses an HTTP error response from a Google Cloud Platform API and extracts the error code + and message fields. + + Args: + content: Decoded content of the response body. + status_code: HTTP status code. + + Returns: + tuple: A tuple containing error code and message. + """ + data = {} + try: + parsed_body = json.loads(content) + if isinstance(parsed_body, dict): + data = parsed_body + except ValueError: + pass + + error_dict = data.get('error', {}) + msg = error_dict.get('message') + if not msg: + msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format(status_code, content) + return error_dict, msg diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 35d9e4ccd..63cbbf4be 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -75,6 +75,11 @@ WebpushNotification = _messaging_utils.WebpushNotification WebpushNotificationAction = _messaging_utils.WebpushNotificationAction +QuotaExceededError = _messaging_utils.QuotaExceededError +SenderIdMismatchError = _messaging_utils.SenderIdMismatchError +ThirdPartyAuthError = _messaging_utils.ThirdPartyAuthError +UnregisteredError = _messaging_utils.UnregisteredError + def _get_messaging_service(app): return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService) @@ -94,7 +99,7 @@ def send(message, dry_run=False, app=None): string: A message ID string that uniquely identifies the sent the message. Raises: - ApiCallError: If an error occurs while sending the message to the FCM service. + FirebaseError: If an error occurs while sending the message to the FCM service. ValueError: If the input arguments are invalid. """ return _get_messaging_service(app).send(message, dry_run) @@ -114,7 +119,7 @@ def send_all(messages, dry_run=False, app=None): BatchResponse: A ``messaging.BatchResponse`` instance. Raises: - ApiCallError: If an error occurs while sending the message to the FCM service. + FirebaseError: If an error occurs while sending the message to the FCM service. ValueError: If the input arguments are invalid. """ return _get_messaging_service(app).send_all(messages, dry_run) @@ -134,7 +139,7 @@ def send_multicast(multicast_message, dry_run=False, app=None): BatchResponse: A ``messaging.BatchResponse`` instance. Raises: - ApiCallError: If an error occurs while sending the message to the FCM service. + FirebaseError: If an error occurs while sending the message to the FCM service. ValueError: If the input arguments are invalid. """ if not isinstance(multicast_message, MulticastMessage): @@ -310,21 +315,12 @@ class _MessagingService(object): INTERNAL_ERROR = 'internal-error' UNKNOWN_ERROR = 'unknown-error' - FCM_ERROR_CODES = { - # FCM v1 canonical error codes - 'NOT_FOUND': 'registration-token-not-registered', - 'PERMISSION_DENIED': 'mismatched-credential', - 'RESOURCE_EXHAUSTED': 'message-rate-exceeded', - 'UNAUTHENTICATED': 'invalid-apns-credentials', - - # FCM v1 new error codes - 'APNS_AUTH_ERROR': 'invalid-apns-credentials', - 'INTERNAL': INTERNAL_ERROR, - 'INVALID_ARGUMENT': 'invalid-argument', - 'QUOTA_EXCEEDED': 'message-rate-exceeded', - 'SENDER_ID_MISMATCH': 'mismatched-credential', - 'UNAVAILABLE': 'server-unavailable', - 'UNREGISTERED': 'registration-token-not-registered', + FCM_ERROR_TYPES = { + 'APNS_AUTH_ERROR': ThirdPartyAuthError, + 'QUOTA_EXCEEDED': QuotaExceededError, + 'SENDER_ID_MISMATCH': SenderIdMismatchError, + 'THIRD_PARTY_AUTH_ERROR': ThirdPartyAuthError, + 'UNREGISTERED': UnregisteredError, } IID_ERROR_CODES = { 400: 'invalid-argument', @@ -367,11 +363,7 @@ def send(self, message, dry_run=False): timeout=self._timeout ) except requests.exceptions.RequestException as error: - if error.response is not None: - self._handle_fcm_error(error) - else: - msg = 'Failed to call messaging API: {0}'.format(error) - raise ApiCallError(self.INTERNAL_ERROR, msg, error) + raise self._handle_fcm_error(error) else: return resp['name'] @@ -387,7 +379,7 @@ def send_all(self, messages, dry_run=False): def batch_callback(_, response, error): exception = None if error: - exception = self._parse_batch_error(error) + exception = self._handle_batch_error(error) send_response = SendResponse(response, exception) responses.append(send_response) @@ -407,7 +399,7 @@ def batch_callback(_, response, error): try: batch.execute() except googleapiclient.http.HttpError as error: - raise self._parse_batch_error(error) + raise self._handle_batch_error(error) else: return BatchResponse(responses) @@ -459,17 +451,8 @@ def _postproc(self, _, body): def _handle_fcm_error(self, error): """Handles errors received from the FCM API.""" - data = {} - try: - parsed_body = error.response.json() - if isinstance(parsed_body, dict): - data = parsed_body - except ValueError: - pass - - code, msg = _MessagingService._parse_fcm_error( - data, error.response.content, error.response.status_code) - raise ApiCallError(code, msg, error) + return _utils.handle_platform_error_from_requests( + error, _MessagingService._build_fcm_error_requests) def _handle_iid_error(self, error): """Handles errors received from the Instance ID API.""" @@ -489,38 +472,32 @@ def _handle_iid_error(self, error): error.response.status_code, error.response.content.decode()) raise ApiCallError(code, msg, error) - def _parse_batch_error(self, error): - """Parses a googleapiclient.http.HttpError content in to an ApiCallError.""" - if error.content is None: - msg = 'Failed to call messaging API: {0}'.format(error) - return ApiCallError(self.INTERNAL_ERROR, msg, error) + def _handle_batch_error(self, error): + """Handles errors received from the googleapiclient while making batch requests.""" + return _utils.handle_platform_error_from_googleapiclient( + error, _MessagingService._build_fcm_error_googleapiclient) - data = {} - try: - parsed_body = json.loads(error.content.decode()) - if isinstance(parsed_body, dict): - data = parsed_body - except ValueError: - pass + @classmethod + def _build_fcm_error_requests(cls, error, message, error_dict): + """Parses an error response from the FCM API and creates a FCM-specific exception if + appropriate.""" + exc_type = cls._build_fcm_error(error_dict) + return exc_type(message, cause=error, http_response=error.response) if exc_type else None - code, msg = _MessagingService._parse_fcm_error(data, error.content, error.resp.status) - return ApiCallError(code, msg, error) + @classmethod + def _build_fcm_error_googleapiclient(cls, error, message, error_dict, http_response): + """Parses an error response from the FCM API and creates a FCM-specific exception if + appropriate.""" + exc_type = cls._build_fcm_error(error_dict) + return exc_type(message, cause=error, http_response=http_response) if exc_type else None @classmethod - def _parse_fcm_error(cls, data, content, status_code): - """Parses an error response from the FCM API to a ApiCallError.""" - error_dict = data.get('error', {}) - server_code = None + def _build_fcm_error(cls, error_dict): + if not error_dict: + return None + fcm_code = None for detail in error_dict.get('details', []): if detail.get('@type') == 'type.googleapis.com/google.firebase.fcm.v1.FcmError': - server_code = detail.get('errorCode') + fcm_code = detail.get('errorCode') break - if not server_code: - server_code = error_dict.get('status') - code = _MessagingService.FCM_ERROR_CODES.get(server_code, _MessagingService.UNKNOWN_ERROR) - - msg = error_dict.get('message') - if not msg: - msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format( - status_code, content.decode()) - return code, msg + return _MessagingService.FCM_ERROR_TYPES.get(fcm_code) diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 7ebd5866a..ef5281523 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -16,6 +16,9 @@ import re +import pytest + +from firebase_admin import exceptions from firebase_admin import messaging @@ -47,6 +50,22 @@ def test_send(): msg_id = messaging.send(msg, dry_run=True) assert re.match('^projects/.*/messages/.*$', msg_id) +def test_send_invalid_token(): + msg = messaging.Message( + token=_REGISTRATION_TOKEN, + notification=messaging.Notification('test-title', 'test-body') + ) + with pytest.raises(messaging.SenderIdMismatchError): + messaging.send(msg, dry_run=True) + +def test_send_malformed_token(): + msg = messaging.Message( + token='not-a-token', + notification=messaging.Notification('test-title', 'test-body') + ) + with pytest.raises(exceptions.InvalidArgumentError): + messaging.send(msg, dry_run=True) + def test_send_all(): messages = [ messaging.Message( @@ -75,7 +94,7 @@ def test_send_all(): response = batch_response.responses[2] assert response.success is False - assert response.exception is not None + assert isinstance(response.exception, exceptions.InvalidArgumentError) assert response.message_id is None def test_send_one_hundred(): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index f2897ab3c..98d9ce5e9 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -12,85 +12,324 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json +import socket +import httplib2 +import pytest import requests from requests import models +import six +from googleapiclient import errors from firebase_admin import exceptions from firebase_admin import _utils -def test_timeout_error(): - error = requests.exceptions.Timeout('Test error') - firebase_error = _utils.handle_requests_error(error) - assert isinstance(firebase_error, exceptions.DeadlineExceededError) - assert str(firebase_error) == 'Timed out while making an API call: Test error' - assert firebase_error.cause is error - assert firebase_error.http_response is None - -def test_connection_error(): - error = requests.exceptions.ConnectionError('Test error') - firebase_error = _utils.handle_requests_error(error) - assert isinstance(firebase_error, exceptions.UnavailableError) - assert str(firebase_error) == 'Failed to establish a connection: Test error' - assert firebase_error.cause is error - assert firebase_error.http_response is None - -def test_unknown_transport_error(): - error = requests.exceptions.RequestException('Test error') - firebase_error = _utils.handle_requests_error(error) - assert isinstance(firebase_error, exceptions.UnknownError) - assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' - assert firebase_error.cause is error - assert firebase_error.http_response is None - -def test_http_response(): - resp = models.Response() - resp.status_code = 500 - error = requests.exceptions.RequestException('Test error', response=resp) - firebase_error = _utils.handle_requests_error(error) - assert isinstance(firebase_error, exceptions.InternalError) - assert str(firebase_error) == 'Test error' - assert firebase_error.cause is error - assert firebase_error.http_response is resp - -def test_http_response_with_unknown_status(): - resp = models.Response() - resp.status_code = 501 - error = requests.exceptions.RequestException('Test error', response=resp) - firebase_error = _utils.handle_requests_error(error) - assert isinstance(firebase_error, exceptions.UnknownError) - assert str(firebase_error) == 'Test error' - assert firebase_error.cause is error - assert firebase_error.http_response is resp - -def test_http_response_with_message(): - resp = models.Response() - resp.status_code = 500 - error = requests.exceptions.RequestException('Test error', response=resp) - firebase_error = _utils.handle_requests_error(error, message='Explicit error message') - assert isinstance(firebase_error, exceptions.InternalError) - assert str(firebase_error) == 'Explicit error message' - assert firebase_error.cause is error - assert firebase_error.http_response is resp - -def test_http_response_with_status(): - resp = models.Response() - resp.status_code = 500 - error = requests.exceptions.RequestException('Test error', response=resp) - firebase_error = _utils.handle_requests_error(error, status=503) - assert isinstance(firebase_error, exceptions.UnavailableError) - assert str(firebase_error) == 'Test error' - assert firebase_error.cause is error - assert firebase_error.http_response is resp - -def test_http_response_with_message_and_status(): - resp = models.Response() - resp.status_code = 500 - error = requests.exceptions.RequestException('Test error', response=resp) - firebase_error = _utils.handle_requests_error( - error, message='Explicit error message', status=503) - assert isinstance(firebase_error, exceptions.UnavailableError) - assert str(firebase_error) == 'Explicit error message' - assert firebase_error.cause is error - assert firebase_error.http_response is resp +_NOT_FOUND_ERROR_DICT = { + 'status': 'NOT_FOUND', + 'message': 'test error' +} + + +_NOT_FOUND_PAYLOAD = json.dumps({ + 'error': _NOT_FOUND_ERROR_DICT, +}) + + +class TestRequests(object): + + def test_timeout_error(self): + error = requests.exceptions.Timeout('Test error') + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.DeadlineExceededError) + assert str(firebase_error) == 'Timed out while making an API call: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_requests_connection_error(self): + error = requests.exceptions.ConnectionError('Test error') + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == 'Failed to establish a connection: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_unknown_transport_error(self): + error = requests.exceptions.RequestException('Test error') + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.UnknownError) + assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_http_response(self): + resp, error = self._create_response() + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.InternalError) + assert str(firebase_error) == 'Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + + def test_http_response_with_unknown_status(self): + resp, error = self._create_response(status=501) + firebase_error = _utils.handle_requests_error(error) + assert isinstance(firebase_error, exceptions.UnknownError) + assert str(firebase_error) == 'Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + + def test_http_response_with_message(self): + resp, error = self._create_response() + firebase_error = _utils.handle_requests_error(error, message='Explicit error message') + assert isinstance(firebase_error, exceptions.InternalError) + assert str(firebase_error) == 'Explicit error message' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + + def test_http_response_with_code(self): + resp, error = self._create_response() + firebase_error = _utils.handle_requests_error(error, code=exceptions.UNAVAILABLE) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == 'Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + + def test_http_response_with_message_and_code(self): + resp, error = self._create_response() + firebase_error = _utils.handle_requests_error( + error, message='Explicit error message', code=exceptions.UNAVAILABLE) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == 'Explicit error message' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + + def test_handle_platform_error(self): + resp, error = self._create_response(payload=_NOT_FOUND_PAYLOAD) + firebase_error = _utils.handle_platform_error_from_requests(error) + assert isinstance(firebase_error, exceptions.NotFoundError) + assert str(firebase_error) == 'test error' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + + def test_handle_platform_error_with_no_response(self): + error = requests.exceptions.RequestException('Test error') + firebase_error = _utils.handle_platform_error_from_requests(error) + assert isinstance(firebase_error, exceptions.UnknownError) + assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_handle_platform_error_with_no_error_code(self): + resp, error = self._create_response(payload='no error code') + firebase_error = _utils.handle_platform_error_from_requests(error) + assert isinstance(firebase_error, exceptions.InternalError) + message = 'Unexpected HTTP response with status: 500; body: no error code' + assert str(firebase_error) == message + assert firebase_error.cause is error + assert firebase_error.http_response is resp + + def test_handle_platform_error_with_custom_handler(self): + resp, error = self._create_response(payload=_NOT_FOUND_PAYLOAD) + invocations = [] + + def _custom_handler(cause, message, error_dict): + invocations.append((cause, message, error_dict)) + return exceptions.InvalidArgumentError('Custom message', cause, cause.response) + + firebase_error = _utils.handle_platform_error_from_requests(error, _custom_handler) + + assert isinstance(firebase_error, exceptions.InvalidArgumentError) + assert str(firebase_error) == 'Custom message' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + assert len(invocations) == 1 + args = invocations[0] + assert len(args) == 3 + assert args[0] is error + assert args[1] == 'test error' + assert args[2] == _NOT_FOUND_ERROR_DICT + + def test_handle_platform_error_with_custom_handler_ignore(self): + resp, error = self._create_response(payload=_NOT_FOUND_PAYLOAD) + invocations = [] + + def _custom_handler(cause, message, error_dict): + invocations.append((cause, message, error_dict)) + return None + + firebase_error = _utils.handle_platform_error_from_requests(error, _custom_handler) + + assert isinstance(firebase_error, exceptions.NotFoundError) + assert str(firebase_error) == 'test error' + assert firebase_error.cause is error + assert firebase_error.http_response is resp + assert len(invocations) == 1 + args = invocations[0] + assert len(args) == 3 + assert args[0] is error + assert args[1] == 'test error' + assert args[2] == _NOT_FOUND_ERROR_DICT + + def _create_response(self, status=500, payload=None): + resp = models.Response() + resp.status_code = status + if payload: + resp.raw = six.BytesIO(payload.encode()) + exc = requests.exceptions.RequestException('Test error', response=resp) + return resp, exc + + +class TestGoogleApiClient(object): + + @pytest.mark.parametrize('error', [ + socket.timeout('Test error'), + socket.error('Read timed out') + ]) + def test_googleapicleint_timeout_error(self, error): + firebase_error = _utils.handle_googleapiclient_error(error) + assert isinstance(firebase_error, exceptions.DeadlineExceededError) + assert str(firebase_error) == 'Timed out while making an API call: {0}'.format(error) + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_googleapiclient_connection_error(self): + error = httplib2.ServerNotFoundError('Test error') + firebase_error = _utils.handle_googleapiclient_error(error) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == 'Failed to establish a connection: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_unknown_transport_error(self): + error = socket.error('Test error') + firebase_error = _utils.handle_googleapiclient_error(error) + assert isinstance(firebase_error, exceptions.UnknownError) + assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_http_response(self): + error = self._create_http_error() + firebase_error = _utils.handle_googleapiclient_error(error) + assert isinstance(firebase_error, exceptions.InternalError) + assert str(firebase_error) == str(error) + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == 'Body' + + def test_http_response_with_unknown_status(self): + error = self._create_http_error(status=501) + firebase_error = _utils.handle_googleapiclient_error(error) + assert isinstance(firebase_error, exceptions.UnknownError) + assert str(firebase_error) == str(error) + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 501 + assert firebase_error.http_response.content.decode() == 'Body' + + def test_http_response_with_message(self): + error = self._create_http_error() + firebase_error = _utils.handle_googleapiclient_error( + error, message='Explicit error message') + assert isinstance(firebase_error, exceptions.InternalError) + assert str(firebase_error) == 'Explicit error message' + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == 'Body' + + def test_http_response_with_code(self): + error = self._create_http_error() + firebase_error = _utils.handle_googleapiclient_error( + error, code=exceptions.UNAVAILABLE) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == str(error) + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == 'Body' + + def test_http_response_with_message_and_code(self): + error = self._create_http_error() + firebase_error = _utils.handle_googleapiclient_error( + error, message='Explicit error message', code=exceptions.UNAVAILABLE) + assert isinstance(firebase_error, exceptions.UnavailableError) + assert str(firebase_error) == 'Explicit error message' + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == 'Body' + + def test_handle_platform_error(self): + error = self._create_http_error(payload=_NOT_FOUND_PAYLOAD) + firebase_error = _utils.handle_platform_error_from_googleapiclient(error) + assert isinstance(firebase_error, exceptions.NotFoundError) + assert str(firebase_error) == 'test error' + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == _NOT_FOUND_PAYLOAD + + def test_handle_platform_error_with_no_response(self): + error = socket.error('Test error') + firebase_error = _utils.handle_platform_error_from_googleapiclient(error) + assert isinstance(firebase_error, exceptions.UnknownError) + assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' + assert firebase_error.cause is error + assert firebase_error.http_response is None + + def test_handle_platform_error_with_no_error_code(self): + error = self._create_http_error(payload='no error code') + firebase_error = _utils.handle_platform_error_from_googleapiclient(error) + assert isinstance(firebase_error, exceptions.InternalError) + message = 'Unexpected HTTP response with status: 500; body: no error code' + assert str(firebase_error) == message + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == 'no error code' + + def test_handle_platform_error_with_custom_handler(self): + error = self._create_http_error(payload=_NOT_FOUND_PAYLOAD) + invocations = [] + + def _custom_handler(cause, message, error_dict, http_response): + invocations.append((cause, message, error_dict, http_response)) + return exceptions.InvalidArgumentError('Custom message', cause, http_response) + + firebase_error = _utils.handle_platform_error_from_googleapiclient(error, _custom_handler) + + assert isinstance(firebase_error, exceptions.InvalidArgumentError) + assert str(firebase_error) == 'Custom message' + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == _NOT_FOUND_PAYLOAD + assert len(invocations) == 1 + args = invocations[0] + assert len(args) == 4 + assert args[0] is error + assert args[1] == 'test error' + assert args[2] == _NOT_FOUND_ERROR_DICT + assert args[3] is not None + + def test_handle_platform_error_with_custom_handler_ignore(self): + error = self._create_http_error(payload=_NOT_FOUND_PAYLOAD) + invocations = [] + + def _custom_handler(cause, message, error_dict, http_response): + invocations.append((cause, message, error_dict, http_response)) + return None + + firebase_error = _utils.handle_platform_error_from_googleapiclient(error, _custom_handler) + + assert isinstance(firebase_error, exceptions.NotFoundError) + assert str(firebase_error) == 'test error' + assert firebase_error.cause is error + assert firebase_error.http_response.status_code == 500 + assert firebase_error.http_response.content.decode() == _NOT_FOUND_PAYLOAD + assert len(invocations) == 1 + args = invocations[0] + assert len(args) == 4 + assert args[0] is error + assert args[1] == 'test error' + assert args[2] == _NOT_FOUND_ERROR_DICT + assert args[3] is not None + + def _create_http_error(self, status=500, payload='Body'): + resp = httplib2.Response({'status': status}) + return errors.HttpError(resp, payload.encode()) diff --git a/tests/test_messaging.py b/tests/test_messaging.py index de940b591..cf99c36ba 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -23,6 +23,7 @@ from googleapiclient.http import HttpMockSequence import firebase_admin +from firebase_admin import exceptions from firebase_admin import messaging from tests import testutils @@ -31,7 +32,20 @@ NON_DICT_ARGS = ['', list(), tuple(), True, False, 1, 0, {1: 'foo'}, {'foo': 1}] NON_OBJECT_ARGS = [list(), tuple(), dict(), 'foo', 0, 1, True, False] NON_LIST_ARGS = ['', tuple(), dict(), True, False, 1, 0, [1], ['foo', 1]] -HTTP_ERRORS = [400, 404, 500] +HTTP_ERRORS = [400, 404, 500] # TODO(hkj): Remove this when IID tests are updated. +HTTP_ERROR_CODES = { + 400: exceptions.InvalidArgumentError, + 404: exceptions.NotFoundError, + 500: exceptions.InternalError, + 503: exceptions.UnavailableError, +} +FCM_ERROR_CODES = { + 'APNS_AUTH_ERROR': messaging.ThirdPartyAuthError, + 'QUOTA_EXCEEDED': messaging.QuotaExceededError, + 'SENDER_ID_MISMATCH': messaging.SenderIdMismatchError, + 'THIRD_PARTY_AUTH_ERROR': messaging.ThirdPartyAuthError, + 'UNREGISTERED': messaging.UnregisteredError, +} def check_encoding(msg, expected=None): @@ -39,6 +53,13 @@ def check_encoding(msg, expected=None): if expected: assert encoded == expected +def check_exception(exception, message, status): + assert isinstance(exception, exceptions.FirebaseError) + assert str(exception) == message + assert exception.cause is not None + assert exception.http_response is not None + assert exception.http_response.status_code == status + class TestMulticastMessage(object): @@ -1258,15 +1279,14 @@ def test_send(self): body = {'message': messaging._MessagingService.encode_message(msg)} assert json.loads(recorder[0].body.decode()) == body - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_send_error(self, status): + @pytest.mark.parametrize('status,exc_type', HTTP_ERROR_CODES.items()) + def test_send_error(self, status, exc_type): _, recorder = self._instrument_messaging_service(status=status, payload='{}') msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.send(msg) expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) - assert str(excinfo.value) == expected - assert str(excinfo.value.code) == messaging._MessagingService.UNKNOWN_ERROR + check_exception(excinfo.value, expected, status) assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('explicit-project-id') @@ -1275,7 +1295,7 @@ def test_send_error(self, status): body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} assert json.loads(recorder[0].body.decode()) == body - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_detailed_error(self, status): payload = json.dumps({ 'error': { @@ -1285,17 +1305,16 @@ def test_send_detailed_error(self, status): }) _, recorder = self._instrument_messaging_service(status=status, payload=payload) msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: messaging.send(msg) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'invalid-argument' + check_exception(excinfo.value, 'test error', status) assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('explicit-project-id') body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} assert json.loads(recorder[0].body.decode()) == body - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_canonical_error_code(self, status): payload = json.dumps({ 'error': { @@ -1305,18 +1324,18 @@ def test_send_canonical_error_code(self, status): }) _, recorder = self._instrument_messaging_service(status=status, payload=payload) msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: messaging.send(msg) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'registration-token-not-registered' + check_exception(excinfo.value, 'test error', status) assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('explicit-project-id') body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} assert json.loads(recorder[0].body.decode()) == body - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_send_fcm_error_code(self, status): + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + @pytest.mark.parametrize('fcm_error_code, exc_type', FCM_ERROR_CODES.items()) + def test_send_fcm_error_code(self, status, fcm_error_code, exc_type): payload = json.dumps({ 'error': { 'status': 'INVALID_ARGUMENT', @@ -1324,17 +1343,41 @@ def test_send_fcm_error_code(self, status): 'details': [ { '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', - 'errorCode': 'UNREGISTERED', + 'errorCode': fcm_error_code, }, ], } }) _, recorder = self._instrument_messaging_service(status=status, payload=payload) msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.send(msg) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'registration-token-not-registered' + check_exception(excinfo.value, 'test error', status) + assert len(recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == self._get_url('explicit-project-id') + body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} + assert json.loads(recorder[0].body.decode()) == body + + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + def test_send_unknown_fcm_error_code(self, status): + payload = json.dumps({ + 'error': { + 'status': 'INVALID_ARGUMENT', + 'message': 'test error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'SOME_UNKNOWN_CODE', + }, + ], + } + }) + _, recorder = self._instrument_messaging_service(status=status, payload=payload) + msg = messaging.Message(topic='foo') + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: + messaging.send(msg) + check_exception(excinfo.value, 'test error', status) assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('explicit-project-id') @@ -1418,7 +1461,7 @@ def test_send_all(self): assert all([r.success for r in batch_response.responses]) assert not any([r.exception for r in batch_response.responses]) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_all_detailed_error(self, status): success_payload = json.dumps({'name': 'message-id'}) error_payload = json.dumps({ @@ -1441,12 +1484,11 @@ def test_send_all_detailed_error(self, status): error_response = batch_response.responses[1] assert error_response.message_id is None assert error_response.success is False - assert error_response.exception is not None exception = error_response.exception - assert str(exception) == 'test error' - assert str(exception.code) == 'invalid-argument' + assert isinstance(exception, exceptions.InvalidArgumentError) + check_exception(exception, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_all_canonical_error_code(self, status): success_payload = json.dumps({'name': 'message-id'}) error_payload = json.dumps({ @@ -1469,13 +1511,13 @@ def test_send_all_canonical_error_code(self, status): error_response = batch_response.responses[1] assert error_response.message_id is None assert error_response.success is False - assert error_response.exception is not None exception = error_response.exception - assert str(exception) == 'test error' - assert str(exception.code) == 'registration-token-not-registered' + assert isinstance(exception, exceptions.NotFoundError) + check_exception(exception, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_send_all_fcm_error_code(self, status): + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) + @pytest.mark.parametrize('fcm_error_code, exc_type', FCM_ERROR_CODES.items()) + def test_send_all_fcm_error_code(self, status, fcm_error_code, exc_type): success_payload = json.dumps({'name': 'message-id'}) error_payload = json.dumps({ 'error': { @@ -1484,7 +1526,7 @@ def test_send_all_fcm_error_code(self, status): 'details': [ { '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', - 'errorCode': 'UNREGISTERED', + 'errorCode': fcm_error_code, }, ], } @@ -1503,22 +1545,20 @@ def test_send_all_fcm_error_code(self, status): error_response = batch_response.responses[1] assert error_response.message_id is None assert error_response.success is False - assert error_response.exception is not None exception = error_response.exception - assert str(exception) == 'test error' - assert str(exception.code) == 'registration-token-not-registered' + assert isinstance(exception, exc_type) + check_exception(exception, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_send_all_batch_error(self, status): + @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + def test_send_all_batch_error(self, status, exc_type): _ = self._instrument_batch_messaging_service(status=status, payload='{}') msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.send_all([msg]) expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) - assert str(excinfo.value) == expected - assert str(excinfo.value.code) == messaging._MessagingService.UNKNOWN_ERROR + check_exception(excinfo.value, expected, status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_all_batch_detailed_error(self, status): payload = json.dumps({ 'error': { @@ -1528,12 +1568,11 @@ def test_send_all_batch_detailed_error(self, status): }) _ = self._instrument_batch_messaging_service(status=status, payload=payload) msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: messaging.send_all([msg]) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'invalid-argument' + check_exception(excinfo.value, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_all_batch_canonical_error_code(self, status): payload = json.dumps({ 'error': { @@ -1543,12 +1582,11 @@ def test_send_all_batch_canonical_error_code(self, status): }) _ = self._instrument_batch_messaging_service(status=status, payload=payload) msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: messaging.send_all([msg]) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'registration-token-not-registered' + check_exception(excinfo.value, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_all_batch_fcm_error_code(self, status): payload = json.dumps({ 'error': { @@ -1564,10 +1602,9 @@ def test_send_all_batch_fcm_error_code(self, status): }) _ = self._instrument_batch_messaging_service(status=status, payload=payload) msg = messaging.Message(topic='foo') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(messaging.UnregisteredError) as excinfo: messaging.send_all([msg]) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'registration-token-not-registered' + check_exception(excinfo.value, 'test error', status) class TestSendMulticast(TestBatch): @@ -1599,7 +1636,7 @@ def test_send_multicast(self): assert all([r.success for r in batch_response.responses]) assert not any([r.exception for r in batch_response.responses]) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_multicast_detailed_error(self, status): success_payload = json.dumps({'name': 'message-id'}) error_payload = json.dumps({ @@ -1624,10 +1661,10 @@ def test_send_multicast_detailed_error(self, status): assert error_response.success is False assert error_response.exception is not None exception = error_response.exception - assert str(exception) == 'test error' - assert str(exception.code) == 'invalid-argument' + assert isinstance(exception, exceptions.InvalidArgumentError) + check_exception(exception, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_multicast_canonical_error_code(self, status): success_payload = json.dumps({'name': 'message-id'}) error_payload = json.dumps({ @@ -1652,10 +1689,10 @@ def test_send_multicast_canonical_error_code(self, status): assert error_response.success is False assert error_response.exception is not None exception = error_response.exception - assert str(exception) == 'test error' - assert str(exception.code) == 'registration-token-not-registered' + assert isinstance(exception, exceptions.NotFoundError) + check_exception(exception, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_multicast_fcm_error_code(self, status): success_payload = json.dumps({'name': 'message-id'}) error_payload = json.dumps({ @@ -1686,20 +1723,19 @@ def test_send_multicast_fcm_error_code(self, status): assert error_response.success is False assert error_response.exception is not None exception = error_response.exception - assert str(exception) == 'test error' - assert str(exception.code) == 'registration-token-not-registered' + assert isinstance(exception, messaging.UnregisteredError) + check_exception(exception, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_send_multicast_batch_error(self, status): + @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + def test_send_multicast_batch_error(self, status, exc_type): _ = self._instrument_batch_messaging_service(status=status, payload='{}') msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.send_multicast(msg) expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) - assert str(excinfo.value) == expected - assert str(excinfo.value.code) == messaging._MessagingService.UNKNOWN_ERROR + check_exception(excinfo.value, expected, status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_multicast_batch_detailed_error(self, status): payload = json.dumps({ 'error': { @@ -1709,12 +1745,11 @@ def test_send_multicast_batch_detailed_error(self, status): }) _ = self._instrument_batch_messaging_service(status=status, payload=payload) msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: messaging.send_multicast(msg) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'invalid-argument' + check_exception(excinfo.value, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_multicast_batch_canonical_error_code(self, status): payload = json.dumps({ 'error': { @@ -1724,12 +1759,11 @@ def test_send_multicast_batch_canonical_error_code(self, status): }) _ = self._instrument_batch_messaging_service(status=status, payload=payload) msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: messaging.send_multicast(msg) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'registration-token-not-registered' + check_exception(excinfo.value, 'test error', status) - @pytest.mark.parametrize('status', HTTP_ERRORS) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_multicast_batch_fcm_error_code(self, status): payload = json.dumps({ 'error': { @@ -1745,10 +1779,9 @@ def test_send_multicast_batch_fcm_error_code(self, status): }) _ = self._instrument_batch_messaging_service(status=status, payload=payload) msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(messaging.UnregisteredError) as excinfo: messaging.send_multicast(msg) - assert str(excinfo.value) == 'test error' - assert str(excinfo.value.code) == 'registration-token-not-registered' + check_exception(excinfo.value, 'test error', status) class TestTopicManagement(object): From fa843f3aa2262e5fb643b11d29598c48213eb8ad Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 26 Jun 2019 16:23:15 -0700 Subject: [PATCH 03/37] Migrated remaining messaging APIs to new error types (#298) * Migrated FCM send APIs to the new error handling regime * Moved error parsing logic to _utils * Refactored OP error handling code * Fixing a broken test * Added utils for handling googleapiclient errors * Added tests for new error handling logic * Updated public API docs * Fixing test for python3 * Cleaning up the error code lookup code * Cleaning up the error parsing APIs * Cleaned up error parsing logic; Updated docs * Migrated the FCM IID APIs to the new error types --- firebase_admin/messaging.py | 49 +++++++++++-------------------------- tests/test_messaging.py | 38 ++++++++++------------------ 2 files changed, 27 insertions(+), 60 deletions(-) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 63cbbf4be..bfc611b26 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -36,7 +36,6 @@ 'AndroidNotification', 'APNSConfig', 'APNSPayload', - 'ApiCallError', 'Aps', 'ApsAlert', 'BatchResponse', @@ -45,8 +44,12 @@ 'Message', 'MulticastMessage', 'Notification', + 'QuotaExceededError', + 'SenderIdMismatchError', 'SendResponse', + 'ThirdPartyAuthError', 'TopicManagementResponse', + 'UnregisteredError', 'WebpushConfig', 'WebpushFcmOptions', 'WebpushNotification', @@ -167,7 +170,7 @@ def subscribe_to_topic(tokens, topic, app=None): TopicManagementResponse: A ``TopicManagementResponse`` instance. Raises: - ApiCallError: If an error occurs while communicating with instance ID service. + FirebaseError: If an error occurs while communicating with instance ID service. ValueError: If the input arguments are invalid. """ return _get_messaging_service(app).make_topic_management_request( @@ -186,7 +189,7 @@ def unsubscribe_from_topic(tokens, topic, app=None): TopicManagementResponse: A ``TopicManagementResponse`` instance. Raises: - ApiCallError: If an error occurs while communicating with instance ID service. + FirebaseError: If an error occurs while communicating with instance ID service. ValueError: If the input arguments are invalid. """ return _get_messaging_service(app).make_topic_management_request( @@ -243,21 +246,6 @@ def errors(self): return self._errors -class ApiCallError(Exception): - """Represents an Exception encountered while invoking the FCM API. - - Attributes: - code: A string error code. - message: A error message string. - detail: Original low-level exception. - """ - - def __init__(self, code, message, detail=None): - Exception.__init__(self, message) - self.code = code - self.detail = detail - - class BatchResponse(object): """The response received from a batch request to the FCM API.""" @@ -300,7 +288,7 @@ def success(self): @property def exception(self): - """An ApiCallError if an error occurs while sending the message to the FCM service.""" + """A FirebaseError if an error occurs while sending the message to the FCM service.""" return self._exception @@ -313,8 +301,6 @@ class _MessagingService(object): IID_HEADERS = {'access_token_auth': 'true'} JSON_ENCODER = _messaging_utils.MessageEncoder() - INTERNAL_ERROR = 'internal-error' - UNKNOWN_ERROR = 'unknown-error' FCM_ERROR_TYPES = { 'APNS_AUTH_ERROR': ThirdPartyAuthError, 'QUOTA_EXCEEDED': QuotaExceededError, @@ -322,13 +308,6 @@ class _MessagingService(object): 'THIRD_PARTY_AUTH_ERROR': ThirdPartyAuthError, 'UNREGISTERED': UnregisteredError, } - IID_ERROR_CODES = { - 400: 'invalid-argument', - 401: 'authentication-error', - 403: 'authentication-error', - 500: INTERNAL_ERROR, - 503: 'server-unavailable', - } def __init__(self, app): project_id = app.project_id @@ -431,10 +410,7 @@ def make_topic_management_request(self, tokens, topic, operation): timeout=self._timeout ) except requests.exceptions.RequestException as error: - if error.response is not None: - self._handle_iid_error(error) - else: - raise ApiCallError(self.INTERNAL_ERROR, 'Failed to call instance ID API.', error) + raise self._handle_iid_error(error) else: return TopicManagementResponse(resp) @@ -456,6 +432,9 @@ def _handle_fcm_error(self, error): def _handle_iid_error(self, error): """Handles errors received from the Instance ID API.""" + if error.response is None: + raise _utils.handle_requests_error(error) + data = {} try: parsed_body = error.response.json() @@ -464,13 +443,13 @@ def _handle_iid_error(self, error): except ValueError: pass - code = _MessagingService.IID_ERROR_CODES.get( - error.response.status_code, _MessagingService.UNKNOWN_ERROR) + # IID error response format: {"error": "some error message"} msg = data.get('error') if not msg: msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format( error.response.status_code, error.response.content.decode()) - raise ApiCallError(code, msg, error) + + return _utils.handle_requests_error(error, msg) def _handle_batch_error(self, error): """Handles errors received from the googleapiclient while making batch requests.""" diff --git a/tests/test_messaging.py b/tests/test_messaging.py index cf99c36ba..421556da3 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -32,9 +32,9 @@ NON_DICT_ARGS = ['', list(), tuple(), True, False, 1, 0, {1: 'foo'}, {'foo': 1}] NON_OBJECT_ARGS = [list(), tuple(), dict(), 'foo', 0, 1, True, False] NON_LIST_ARGS = ['', tuple(), dict(), True, False, 1, 0, [1], ['foo', 1]] -HTTP_ERRORS = [400, 404, 500] # TODO(hkj): Remove this when IID tests are updated. HTTP_ERROR_CODES = { 400: exceptions.InvalidArgumentError, + 403: exceptions.PermissionDeniedError, 404: exceptions.NotFoundError, 500: exceptions.InternalError, 503: exceptions.UnavailableError, @@ -1859,30 +1859,24 @@ def test_subscribe_to_topic(self, args): assert recorder[0].url == self._get_url('iid/v1:batchAdd') assert json.loads(recorder[0].body.decode()) == args[2] - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_subscribe_to_topic_error(self, status): + @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + def test_subscribe_to_topic_error(self, status, exc_type): _, recorder = self._instrument_iid_service( status=status, payload=self._DEFAULT_ERROR_RESPONSE) - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.subscribe_to_topic('foo', 'test-topic') assert str(excinfo.value) == 'error_reason' - code = messaging._MessagingService.IID_ERROR_CODES.get( - status, messaging._MessagingService.UNKNOWN_ERROR) - assert excinfo.value.code == code assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('iid/v1:batchAdd') - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_subscribe_to_topic_non_json_error(self, status): + @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + def test_subscribe_to_topic_non_json_error(self, status, exc_type): _, recorder = self._instrument_iid_service(status=status, payload='not json') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.subscribe_to_topic('foo', 'test-topic') reason = 'Unexpected HTTP response with status: {0}; body: not json'.format(status) - code = messaging._MessagingService.IID_ERROR_CODES.get( - status, messaging._MessagingService.UNKNOWN_ERROR) assert str(excinfo.value) == reason - assert excinfo.value.code == code assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('iid/v1:batchAdd') @@ -1897,30 +1891,24 @@ def test_unsubscribe_from_topic(self, args): assert recorder[0].url == self._get_url('iid/v1:batchRemove') assert json.loads(recorder[0].body.decode()) == args[2] - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_unsubscribe_from_topic_error(self, status): + @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + def test_unsubscribe_from_topic_error(self, status, exc_type): _, recorder = self._instrument_iid_service( status=status, payload=self._DEFAULT_ERROR_RESPONSE) - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.unsubscribe_from_topic('foo', 'test-topic') assert str(excinfo.value) == 'error_reason' - code = messaging._MessagingService.IID_ERROR_CODES.get( - status, messaging._MessagingService.UNKNOWN_ERROR) - assert excinfo.value.code == code assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('iid/v1:batchRemove') - @pytest.mark.parametrize('status', HTTP_ERRORS) - def test_unsubscribe_from_topic_non_json_error(self, status): + @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) + def test_unsubscribe_from_topic_non_json_error(self, status, exc_type): _, recorder = self._instrument_iid_service(status=status, payload='not json') - with pytest.raises(messaging.ApiCallError) as excinfo: + with pytest.raises(exc_type) as excinfo: messaging.unsubscribe_from_topic('foo', 'test-topic') reason = 'Unexpected HTTP response with status: {0}; body: not json'.format(status) - code = messaging._MessagingService.IID_ERROR_CODES.get( - status, messaging._MessagingService.UNKNOWN_ERROR) assert str(excinfo.value) == reason - assert excinfo.value.code == code assert len(recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == self._get_url('iid/v1:batchRemove') From b27216bfed58e8d1e9063e3892ea3ec9764e8e77 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 4 Jul 2019 19:56:08 -0700 Subject: [PATCH 04/37] Introducing TokenSignError to represent custom token creation errors (#302) * Migrated FCM send APIs to the new error handling regime * Moved error parsing logic to _utils * Refactored OP error handling code * Fixing a broken test * Added utils for handling googleapiclient errors * Added tests for new error handling logic * Updated public API docs * Fixing test for python3 * Cleaning up the error code lookup code * Cleaning up the error parsing APIs * Cleaned up error parsing logic; Updated docs * Migrated the FCM IID APIs to the new error types * Migrated custom token API to new error types --- firebase_admin/_token_gen.py | 16 ++++++++++++---- firebase_admin/auth.py | 9 ++++----- tests/test_token_gen.py | 15 +++++++++------ 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index e2eaa5715..0fcb1d0c7 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -21,13 +21,15 @@ import requests import six from google.auth import credentials -from google.auth import exceptions from google.auth import iam from google.auth import jwt from google.auth import transport +import google.auth.exceptions import google.oauth2.id_token import google.oauth2.service_account +from firebase_admin import exceptions + # ID token constants ID_TOKEN_ISSUER_PREFIX = 'https://securetoken.google.com/' @@ -53,7 +55,6 @@ # Error codes COOKIE_CREATE_ERROR = 'COOKIE_CREATE_ERROR' -TOKEN_SIGN_ERROR = 'TOKEN_SIGN_ERROR' class ApiCallError(Exception): @@ -177,9 +178,9 @@ def create_custom_token(self, uid, developer_claims=None): payload['claims'] = developer_claims try: return jwt.encode(signing_provider.signer, payload) - except exceptions.TransportError as error: + except google.auth.exceptions.TransportError as error: msg = 'Failed to sign custom token. {0}'.format(error) - raise ApiCallError(TOKEN_SIGN_ERROR, msg, error) + raise TokenSignError(msg, error) def create_session_cookie(self, id_token, expires_in): @@ -339,3 +340,10 @@ def verify(self, token, request): certs_url=self.cert_url) verified_claims['uid'] = verified_claims['sub'] return verified_claims + + +class TokenSignError(exceptions.UnknownError): + """Unexpected error while signing a Firebase custom token.""" + + def __init__(self, message, cause): + exceptions.UnknownError.__init__(self, message, cause) diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 0800d7c1e..5e168b2fe 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -42,6 +42,7 @@ 'ExportedUserRecord', 'ImportUserRecord', 'ListUsersPage', + 'TokenSignError', 'UserImportHash', 'UserImportResult', 'UserInfo', @@ -75,6 +76,7 @@ ListUsersPage = _user_mgt.ListUsersPage UserImportHash = _user_import.UserImportHash ImportUserRecord = _user_import.ImportUserRecord +TokenSignError = _token_gen.TokenSignError UserImportResult = _user_import.UserImportResult UserInfo = _user_mgt.UserInfo UserMetadata = _user_mgt.UserMetadata @@ -115,13 +117,10 @@ def create_custom_token(uid, developer_claims=None, app=None): Raises: ValueError: If input parameters are invalid. - AuthError: If an error occurs while creating the token using the remote IAM service. + TokenSignError: If an error occurs while signing the token using the remote IAM service. """ token_generator = _get_auth_service(app).token_generator - try: - return token_generator.create_custom_token(uid, developer_claims) - except _token_gen.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + return token_generator.create_custom_token(uid, developer_claims) def verify_id_token(id_token, app=None, check_revoked=False): diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index 412ba3d0e..3a25640aa 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -21,8 +21,8 @@ import time from google.auth import crypt -from google.auth import exceptions from google.auth import jwt +import google.auth.exceptions import google.oauth2.id_token import pytest from pytest_localserver import plugin @@ -31,6 +31,7 @@ import firebase_admin from firebase_admin import auth from firebase_admin import credentials +from firebase_admin import exceptions from firebase_admin import _token_gen from tests import testutils @@ -219,10 +220,12 @@ def test_sign_with_iam_error(self): try: iam_resp = '{"error": {"code": 403, "message": "test error"}}' _overwrite_iam_request(app, testutils.MockRequest(403, iam_resp)) - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.TokenSignError) as excinfo: auth.create_custom_token(MOCK_UID, app=app) - assert excinfo.value.code == _token_gen.TOKEN_SIGN_ERROR - assert iam_resp in str(excinfo.value) + error = excinfo.value + assert error.code == exceptions.UNKNOWN + assert iam_resp in str(error) + assert isinstance(error.cause, google.auth.exceptions.TransportError) finally: firebase_admin.delete_app(app) @@ -421,7 +424,7 @@ def test_custom_token(self, auth_app): def test_certificate_request_failure(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, testutils.MockRequest(404, 'not found')) - with pytest.raises(exceptions.TransportError): + with pytest.raises(google.auth.exceptions.TransportError): auth.verify_id_token(TEST_ID_TOKEN, app=user_mgt_app) @@ -521,7 +524,7 @@ def test_custom_token(self, auth_app): def test_certificate_request_failure(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, testutils.MockRequest(404, 'not found')) - with pytest.raises(exceptions.TransportError): + with pytest.raises(google.auth.exceptions.TransportError): auth.verify_session_cookie(TEST_SESSION_COOKIE, app=user_mgt_app) From 99929ed8a4f88e9065ad4896e7e83c2981304a84 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 18 Jul 2019 15:18:07 -0700 Subject: [PATCH 05/37] Raising FirebaseError from create_session_cookie() API (#306) * Migrated FCM send APIs to the new error handling regime * Moved error parsing logic to _utils * Refactored OP error handling code * Fixing a broken test * Added utils for handling googleapiclient errors * Added tests for new error handling logic * Updated public API docs * Fixing test for python3 * Cleaning up the error code lookup code * Cleaning up the error parsing APIs * Cleaned up error parsing logic; Updated docs * Migrated the FCM IID APIs to the new error types * Migrated custom token API to new error types * Migrated create cookie API to new error types * Improved error message computation * Refactored the shared error handling code * Fixing lint errors * Renamed variable for clarity --- firebase_admin/_auth_utils.py | 71 ++++++++++++++++++++++++++++++++++ firebase_admin/_http_client.py | 4 ++ firebase_admin/_token_gen.py | 32 ++++----------- firebase_admin/auth.py | 10 ++--- integration/test_auth.py | 5 +++ tests/test_token_gen.py | 33 +++++++++++++--- 6 files changed, 120 insertions(+), 35 deletions(-) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index b6788355c..d9b8c66e0 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -20,6 +20,9 @@ import six from six.moves import urllib +from firebase_admin import exceptions +from firebase_admin import _utils + MAX_CLAIMS_PAYLOAD_SIZE = 1000 RESERVED_CLAIMS = set([ @@ -188,3 +191,71 @@ def validate_action_type(action_type): raise ValueError('Invalid action type provided action_type: {0}. \ Valid values are {1}'.format(action_type, ', '.join(VALID_EMAIL_ACTION_TYPES))) return action_type + + +class InvalidIdTokenError(exceptions.InvalidArgumentError): + """The provided ID token is not a valid Firebase ID token.""" + + default_message = 'The provided ID token is invalid' + + def __init__(self, message, cause, http_response=None): + exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) + + +class UnexpectedResponseError(exceptions.UnknownError): + """Backend service responded with an unexpected or malformed response.""" + + def __init__(self, message, cause=None, http_response=None): + exceptions.UnknownError.__init__(self, message, cause, http_response) + + +_CODE_TO_EXC_TYPE = { + 'INVALID_ID_TOKEN': InvalidIdTokenError, +} + + +def handle_auth_backend_error(error): + """Converts a requests error received from the Firebase Auth service into a FirebaseError.""" + if error.response is None: + raise _utils.handle_requests_error(error) + + code, custom_message = _parse_error_body(error.response) + if not code: + msg = 'Unexpected error response: {0}'.format(error.response.content.decode()) + raise _utils.handle_requests_error(error, message=msg) + + exc_type = _CODE_TO_EXC_TYPE.get(code) + msg = _build_error_message(code, exc_type, custom_message) + if not exc_type: + return _utils.handle_requests_error(error, message=msg) + + return exc_type(msg, cause=error, http_response=error.response) + + +def _parse_error_body(response): + """Parses the given error response to extract Auth error code and message.""" + error_dict = {} + try: + parsed_body = response.json() + if isinstance(parsed_body, dict): + error_dict = parsed_body.get('error', {}) + except ValueError: + pass + + # Auth error response format: {"error": {"message": "AUTH_ERROR_CODE: Optional text"}} + code = error_dict.get('message') + custom_message = None + if code: + separator = code.find(':') + if separator != -1: + custom_message = code[separator + 1:].strip() + code = code[:separator] + + return code, custom_message + + +def _build_error_message(code, exc_type, custom_message): + default_message = exc_type.default_message if ( + exc_type and hasattr(exc_type, 'default_message')) else 'Error while calling Auth service' + ext = ' {0}'.format(custom_message) if custom_message else '' + return '{0} ({1}).{2}'.format(default_message, code, ext) diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index 73028f833..eb8c4027a 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -109,6 +109,10 @@ def headers(self, method, url, **kwargs): resp = self.request(method, url, **kwargs) return resp.headers + def body_and_response(self, method, url, **kwargs): + resp = self.request(method, url, **kwargs) + return self.parse_body(resp), resp + def body(self, method, url, **kwargs): resp = self.request(method, url, **kwargs) return self.parse_body(resp) diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 0fcb1d0c7..0ea34f77c 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -29,6 +29,7 @@ import google.oauth2.service_account from firebase_admin import exceptions +from firebase_admin import _auth_utils # ID token constants @@ -53,18 +54,6 @@ METADATA_SERVICE_URL = ('http://metadata/computeMetadata/v1/instance/service-accounts/' 'default/email') -# Error codes -COOKIE_CREATE_ERROR = 'COOKIE_CREATE_ERROR' - - -class ApiCallError(Exception): - """Represents an Exception encountered while invoking the ID toolkit API.""" - - def __init__(self, code, message, error=None): - Exception.__init__(self, message) - self.code = code - self.detail = error - class _SigningProvider(object): """Stores a reference to a google.auth.crypto.Signer.""" @@ -207,20 +196,15 @@ def create_session_cookie(self, id_token, expires_in): 'validDuration': expires_in, } try: - response = self.client.body('post', ':createSessionCookie', json=payload) + body, http_resp = self.client.body_and_response( + 'post', ':createSessionCookie', json=payload) except requests.exceptions.RequestException as error: - self._handle_http_error(COOKIE_CREATE_ERROR, 'Failed to create session cookie', error) - else: - if not response or not response.get('sessionCookie'): - raise ApiCallError(COOKIE_CREATE_ERROR, 'Failed to create session cookie.') - return response.get('sessionCookie') - - def _handle_http_error(self, code, msg, error): - if error.response is not None: - msg += '\nServer response: {0}'.format(error.response.content.decode()) + raise _auth_utils.handle_auth_backend_error(error) else: - msg += '\nReason: {0}'.format(error) - raise ApiCallError(code, msg, error) + if not body or not body.get('sessionCookie'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to create session cookie.', http_response=http_resp) + return body.get('sessionCookie') class TokenVerifier(object): diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 5e168b2fe..cd9a882e6 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -22,6 +22,7 @@ import time import firebase_admin +from firebase_admin import _auth_utils from firebase_admin import _http_client from firebase_admin import _token_gen from firebase_admin import _user_import @@ -76,7 +77,9 @@ ListUsersPage = _user_mgt.ListUsersPage UserImportHash = _user_import.UserImportHash ImportUserRecord = _user_import.ImportUserRecord +InvalidIdTokenError = _auth_utils.InvalidIdTokenError TokenSignError = _token_gen.TokenSignError +UnexpectedResponseError = _auth_utils.UnexpectedResponseError UserImportResult = _user_import.UserImportResult UserInfo = _user_mgt.UserInfo UserMetadata = _user_mgt.UserMetadata @@ -169,13 +172,10 @@ def create_session_cookie(id_token, expires_in, app=None): Raises: ValueError: If input parameters are invalid. - AuthError: If an error occurs while creating the cookie. + FirebaseError: If an error occurs while creating the cookie. """ token_generator = _get_auth_service(app).token_generator - try: - return token_generator.create_session_cookie(id_token, expires_in) - except _token_gen.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + return token_generator.create_session_cookie(id_token, expires_in) def verify_session_cookie(session_cookie, check_revoked=False, app=None): diff --git a/integration/test_auth.py b/integration/test_auth.py index 53577b827..a2905d881 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -129,6 +129,11 @@ def test_session_cookies(api_key): estimated_exp = int(time.time() + expires_in.total_seconds()) assert abs(claims['exp'] - estimated_exp) < 5 +def test_session_cookie_error(): + expires_in = datetime.timedelta(days=1) + with pytest.raises(auth.InvalidIdTokenError): + auth.create_session_cookie('not.a.token', expires_in=expires_in) + def test_get_non_existing_user(): with pytest.raises(auth.AuthError) as excinfo: auth.get_user('non.existing') diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index 3a25640aa..baf8d9515 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -301,17 +301,38 @@ def test_valid_args(self, user_mgt_app, expires_in): assert request == {'idToken' : 'id_token', 'validDuration': 3600} def test_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error":{"message": "INVALID_ID_TOKEN"}}') + with pytest.raises(auth.InvalidIdTokenError) as excinfo: + auth.create_session_cookie('id_token', expires_in=3600, app=user_mgt_app) + assert excinfo.value.code == exceptions.INVALID_ARGUMENT + assert str(excinfo.value) == 'The provided ID token is invalid (INVALID_ID_TOKEN).' + + def test_error_with_details(self, user_mgt_app): + _instrument_user_manager( + user_mgt_app, 500, '{"error":{"message": "INVALID_ID_TOKEN: More details."}}') + with pytest.raises(auth.InvalidIdTokenError) as excinfo: + auth.create_session_cookie('id_token', expires_in=3600, app=user_mgt_app) + assert excinfo.value.code == exceptions.INVALID_ARGUMENT + expected = 'The provided ID token is invalid (INVALID_ID_TOKEN). More details.' + assert str(excinfo.value) == expected + + def test_unexpected_error_code(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 500, '{"error":{"message": "SOMETHING_UNUSUAL"}}') + with pytest.raises(exceptions.InternalError) as excinfo: + auth.create_session_cookie('id_token', expires_in=3600, app=user_mgt_app) + assert str(excinfo.value) == 'Error while calling Auth service (SOMETHING_UNUSUAL).' + + def test_unexpected_error_response(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 500, '{}') + with pytest.raises(exceptions.InternalError) as excinfo: auth.create_session_cookie('id_token', expires_in=3600, app=user_mgt_app) - assert excinfo.value.code == _token_gen.COOKIE_CREATE_ERROR - assert '{"error":"test"}' in str(excinfo.value) + assert str(excinfo.value) == 'Unexpected error response: {}' def test_unexpected_response(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 200, '{}') - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UnexpectedResponseError) as excinfo: auth.create_session_cookie('id_token', expires_in=3600, app=user_mgt_app) - assert excinfo.value.code == _token_gen.COOKIE_CREATE_ERROR + assert excinfo.value.code == exceptions.UNKNOWN assert 'Failed to create session cookie' in str(excinfo.value) From 9fb0766600eadc11ef496cd8e0817133e2bbe81b Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 18 Jul 2019 16:56:19 -0700 Subject: [PATCH 06/37] Introducing UserNotFoundError type (#309) * Added UserNotFoundError type * Fixed some lint errors * Some formatting updates * Updated docs and tests --- firebase_admin/_auth_utils.py | 12 +++++- firebase_admin/_user_mgt.py | 21 ++++----- firebase_admin/auth.py | 35 ++++++--------- integration/test_auth.py | 11 +++-- tests/test_user_mgt.py | 81 +++++++++++++++++++++++++++++------ 5 files changed, 110 insertions(+), 50 deletions(-) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index d9b8c66e0..7e992db06 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -209,8 +209,18 @@ def __init__(self, message, cause=None, http_response=None): exceptions.UnknownError.__init__(self, message, cause, http_response) +class UserNotFoundError(exceptions.NotFoundError): + """No user record found for the specified identifier.""" + + default_message = 'No user record found for the given identifier' + + def __init__(self, message, cause=None, http_response=None): + exceptions.NotFoundError.__init__(self, message, cause, http_response) + + _CODE_TO_EXC_TYPE = { 'INVALID_ID_TOKEN': InvalidIdTokenError, + 'USER_NOT_FOUND': UserNotFoundError, } @@ -243,7 +253,7 @@ def _parse_error_body(response): pass # Auth error response format: {"error": {"message": "AUTH_ERROR_CODE: Optional text"}} - code = error_dict.get('message') + code = error_dict.get('message') if isinstance(error_dict, dict) else None custom_message = None if code: separator = code.find(':') diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 24bb2bdb6..a217d108c 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -24,8 +24,6 @@ from firebase_admin import _user_import -INTERNAL_ERROR = 'INTERNAL_ERROR' -USER_NOT_FOUND_ERROR = 'USER_NOT_FOUND_ERROR' USER_CREATE_ERROR = 'USER_CREATE_ERROR' USER_UPDATE_ERROR = 'USER_UPDATE_ERROR' USER_DELETE_ERROR = 'USER_DELETE_ERROR' @@ -381,6 +379,7 @@ def photo_url(self): def provider_id(self): return self._data.get('providerId') + class ActionCodeSettings(object): """Contains required continue/state URL with optional Android and iOS settings. Used when invoking the email action link generation APIs. @@ -396,6 +395,7 @@ def __init__(self, url, handle_code_in_app=None, dynamic_link_domain=None, ios_b self.android_install_app = android_install_app self.android_minimum_version = android_minimum_version + def encode_action_code_settings(settings): """ Validates the provided action code settings for email link generation and populates the REST api parameters. @@ -463,6 +463,7 @@ def encode_action_code_settings(settings): return parameters + class UserManager(object): """Provides methods for interacting with the Google Identity Toolkit.""" @@ -484,16 +485,16 @@ def get_user(self, **kwargs): raise TypeError('Unsupported keyword arguments: {0}.'.format(kwargs)) try: - response = self._client.body('post', '/accounts:lookup', json=payload) + body, http_resp = self._client.body_and_response( + 'post', '/accounts:lookup', json=payload) except requests.exceptions.RequestException as error: - msg = 'Failed to get user by {0}: {1}.'.format(key_type, key) - self._handle_http_error(INTERNAL_ERROR, msg, error) + raise _auth_utils.handle_auth_backend_error(error) else: - if not response or not response.get('users'): - raise ApiCallError( - USER_NOT_FOUND_ERROR, - 'No user record found for the provided {0}: {1}.'.format(key_type, key)) - return response['users'][0] + if not body or not body.get('users'): + raise _auth_utils.UserNotFoundError( + 'No user record found for the provided {0}: {1}.'.format(key_type, key), + http_response=http_resp) + return body['users'][0] def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): """Retrieves a batch of users.""" diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index cd9a882e6..f654eae42 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -48,6 +48,7 @@ 'UserImportResult', 'UserInfo', 'UserMetadata', + 'UserNotFoundError', 'UserProvider', 'UserRecord', @@ -83,6 +84,7 @@ UserImportResult = _user_import.UserImportResult UserInfo = _user_mgt.UserInfo UserMetadata = _user_mgt.UserMetadata +UserNotFoundError = _auth_utils.UserNotFoundError UserProvider = _user_import.UserProvider UserRecord = _user_mgt.UserRecord @@ -232,15 +234,12 @@ def get_user(uid, app=None): Raises: ValueError: If the user ID is None, empty or malformed. - AuthError: If an error occurs while retrieving the user or if the specified user ID - does not exist. + UserNotFoundError: If the specified user ID does not exist. + FirebaseError: If an error occurs while retrieving the user. """ user_manager = _get_auth_service(app).user_manager - try: - response = user_manager.get_user(uid=uid) - return UserRecord(response) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + response = user_manager.get_user(uid=uid) + return UserRecord(response) def get_user_by_email(email, app=None): @@ -255,15 +254,12 @@ def get_user_by_email(email, app=None): Raises: ValueError: If the email is None, empty or malformed. - AuthError: If an error occurs while retrieving the user or no user exists by the specified - email address. + UserNotFoundError: If no user exists by the specified email address. + FirebaseError: If an error occurs while retrieving the user. """ user_manager = _get_auth_service(app).user_manager - try: - response = user_manager.get_user(email=email) - return UserRecord(response) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + response = user_manager.get_user(email=email) + return UserRecord(response) def get_user_by_phone_number(phone_number, app=None): @@ -278,15 +274,12 @@ def get_user_by_phone_number(phone_number, app=None): Raises: ValueError: If the phone number is None, empty or malformed. - AuthError: If an error occurs while retrieving the user or no user exists by the specified - phone number. + UserNotFoundError: If no user exists by the specified phone number. + FirebaseError: If an error occurs while retrieving the user. """ user_manager = _get_auth_service(app).user_manager - try: - response = user_manager.get_user(phone_number=phone_number) - return UserRecord(response) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + response = user_manager.get_user(phone_number=phone_number) + return UserRecord(response) def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, app=None): diff --git a/integration/test_auth.py b/integration/test_auth.py index a2905d881..3af7be288 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -29,6 +29,7 @@ import google.oauth2.credentials from google.auth import transport + _verify_token_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken' _verify_password_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyPassword' _password_reset_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/resetPassword' @@ -135,14 +136,16 @@ def test_session_cookie_error(): auth.create_session_cookie('not.a.token', expires_in=expires_in) def test_get_non_existing_user(): - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UserNotFoundError) as excinfo: auth.get_user('non.existing') - assert 'USER_NOT_FOUND_ERROR' in str(excinfo.value.code) + assert str(excinfo.value) == 'No user record found for the provided user ID: non.existing.' def test_get_non_existing_user_by_email(): - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UserNotFoundError) as excinfo: auth.get_user_by_email('non.existing@definitely.non.existing') - assert 'USER_NOT_FOUND_ERROR' in str(excinfo.value.code) + error_msg = ('No user record found for the provided email: ' + 'non.existing@definitely.non.existing.') + assert str(excinfo.value) == error_msg def test_update_non_existing_user(): with pytest.raises(auth.AuthError) as excinfo: diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 797e0ce59..951205621 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -21,6 +21,7 @@ import firebase_admin from firebase_admin import auth +from firebase_admin import exceptions from firebase_admin import _auth_utils from firebase_admin import _user_import from firebase_admin import _user_mgt @@ -211,30 +212,79 @@ def test_get_user_by_phone(self, user_mgt_app): def test_get_user_non_existing(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 200, '{"users":[]}') - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UserNotFoundError) as excinfo: auth.get_user('nonexistentuser', user_mgt_app) - assert excinfo.value.code == _user_mgt.USER_NOT_FOUND_ERROR + error_msg = 'No user record found for the provided user ID: nonexistentuser.' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is None + + def test_get_user_by_email_non_existing(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, '{"users":[]}') + with pytest.raises(auth.UserNotFoundError) as excinfo: + auth.get_user_by_email('nonexistent@user', user_mgt_app) + error_msg = 'No user record found for the provided email: nonexistent@user.' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is None + + def test_get_user_by_phone_non_existing(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, '{"users":[]}') + with pytest.raises(auth.UserNotFoundError) as excinfo: + auth.get_user_by_phone_number('+1234567890', user_mgt_app) + error_msg = 'No user record found for the provided phone number: +1234567890.' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is None def test_get_user_http_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error":{"message": "USER_NOT_FOUND"}}') + with pytest.raises(auth.UserNotFoundError) as excinfo: auth.get_user('testuser', user_mgt_app) - assert excinfo.value.code == _user_mgt.INTERNAL_ERROR - assert '{"error":"test"}' in str(excinfo.value) + error_msg = 'No user record found for the given identifier (USER_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def test_get_user_http_error_unexpected_code(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 500, '{"error":{"message": "UNEXPECTED_CODE"}}') + with pytest.raises(exceptions.InternalError) as excinfo: + auth.get_user('testuser', user_mgt_app) + assert str(excinfo.value) == 'Error while calling Auth service (UNEXPECTED_CODE).' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def test_get_user_http_error_malformed_response(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 500, '{"error": "UNEXPECTED_CODE"}') + with pytest.raises(exceptions.InternalError) as excinfo: + auth.get_user('testuser', user_mgt_app) + assert str(excinfo.value) == 'Unexpected error response: {"error": "UNEXPECTED_CODE"}' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None def test_get_user_by_email_http_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error":{"message": "USER_NOT_FOUND"}}') + with pytest.raises(auth.UserNotFoundError) as excinfo: auth.get_user_by_email('non.existent.user@example.com', user_mgt_app) - assert excinfo.value.code == _user_mgt.INTERNAL_ERROR - assert '{"error":"test"}' in str(excinfo.value) + error_msg = 'No user record found for the given identifier (USER_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None def test_get_user_by_phone_http_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error":{"message": "USER_NOT_FOUND"}}') + with pytest.raises(auth.UserNotFoundError) as excinfo: auth.get_user_by_phone_number('+1234567890', user_mgt_app) - assert excinfo.value.code == _user_mgt.INTERNAL_ERROR - assert '{"error":"test"}' in str(excinfo.value) + error_msg = 'No user record found for the given identifier (USER_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None class TestCreateUser(object): @@ -718,6 +768,7 @@ def test_invalid_args(self, arg): with pytest.raises(ValueError): auth.UserMetadata(**arg) + class TestImportUserRecord(object): _INVALID_USERS = ( @@ -1003,6 +1054,7 @@ def test_revoke_refresh_tokens(self, user_mgt_app): assert int(request['validSince']) >= int(before_time) assert int(request['validSince']) <= int(after_time) + class TestActionCodeSetting(object): def test_valid_data(self): @@ -1047,6 +1099,7 @@ def test_encode_action_code_bad_data(self): with pytest.raises(AttributeError): _user_mgt.encode_action_code_settings({"foo":"bar"}) + class TestGenerateEmailActionLink(object): def test_email_verification_no_settings(self, user_mgt_app): From 8a0cf081a355115c5bcd9ff88a9ffd3458e53eb7 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 25 Jul 2019 17:06:05 -0700 Subject: [PATCH 07/37] New error handling support in create/update/delete user APIs (#311) * New error handling support in create/update/delete user APIs * Fixing some lint errors --- firebase_admin/_auth_utils.py | 10 +++++ firebase_admin/_user_mgt.py | 38 +++++++++--------- firebase_admin/auth.py | 35 +++++++---------- integration/test_auth.py | 12 ++---- lint.sh | 2 +- tests/test_user_mgt.py | 73 +++++++++++++++++++++++++++-------- 6 files changed, 104 insertions(+), 66 deletions(-) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 7e992db06..2dfa23e08 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -193,6 +193,15 @@ def validate_action_type(action_type): return action_type +class UidAlreadyExistsError(exceptions.AlreadyExistsError): + """The user with the provided uid already exists.""" + + default_message = 'The user with the provided uid already exists' + + def __init__(self, message, cause, http_response=None): + exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) + + class InvalidIdTokenError(exceptions.InvalidArgumentError): """The provided ID token is not a valid Firebase ID token.""" @@ -219,6 +228,7 @@ def __init__(self, message, cause=None, http_response=None): _CODE_TO_EXC_TYPE = { + 'DUPLICATE_LOCAL_ID': UidAlreadyExistsError, 'INVALID_ID_TOKEN': InvalidIdTokenError, 'USER_NOT_FOUND': UserNotFoundError, } diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index a217d108c..61090da9b 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -24,9 +24,6 @@ from firebase_admin import _user_import -USER_CREATE_ERROR = 'USER_CREATE_ERROR' -USER_UPDATE_ERROR = 'USER_UPDATE_ERROR' -USER_DELETE_ERROR = 'USER_DELETE_ERROR' USER_IMPORT_ERROR = 'USER_IMPORT_ERROR' USER_DOWNLOAD_ERROR = 'LIST_USERS_ERROR' GENERATE_EMAIL_ACTION_LINK_ERROR = 'GENERATE_EMAIL_ACTION_LINK_ERROR' @@ -531,13 +528,14 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None } payload = {k: v for k, v in payload.items() if v is not None} try: - response = self._client.body('post', '/accounts', json=payload) + body, http_resp = self._client.body_and_response('post', '/accounts', json=payload) except requests.exceptions.RequestException as error: - self._handle_http_error(USER_CREATE_ERROR, 'Failed to create new user.', error) + raise _auth_utils.handle_auth_backend_error(error) else: - if not response or not response.get('localId'): - raise ApiCallError(USER_CREATE_ERROR, 'Failed to create new user.') - return response.get('localId') + if not body or not body.get('localId'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to create new user.', http_response=http_resp) + return body.get('localId') def update_user(self, uid, display_name=_UNSPECIFIED, email=None, phone_number=_UNSPECIFIED, photo_url=_UNSPECIFIED, password=None, disabled=None, email_verified=None, @@ -581,26 +579,28 @@ def update_user(self, uid, display_name=_UNSPECIFIED, email=None, phone_number=_ payload = {k: v for k, v in payload.items() if v is not None} try: - response = self._client.body('post', '/accounts:update', json=payload) + body, http_resp = self._client.body_and_response( + 'post', '/accounts:update', json=payload) except requests.exceptions.RequestException as error: - self._handle_http_error( - USER_UPDATE_ERROR, 'Failed to update user: {0}.'.format(uid), error) + raise _auth_utils.handle_auth_backend_error(error) else: - if not response or not response.get('localId'): - raise ApiCallError(USER_UPDATE_ERROR, 'Failed to update user: {0}.'.format(uid)) - return response.get('localId') + if not body or not body.get('localId'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to update user: {0}.'.format(uid), http_response=http_resp) + return body.get('localId') def delete_user(self, uid): """Deletes the user identified by the specified user ID.""" _auth_utils.validate_uid(uid, required=True) try: - response = self._client.body('post', '/accounts:delete', json={'localId' : uid}) + body, http_resp = self._client.body_and_response( + 'post', '/accounts:delete', json={'localId' : uid}) except requests.exceptions.RequestException as error: - self._handle_http_error( - USER_DELETE_ERROR, 'Failed to delete user: {0}.'.format(uid), error) + raise _auth_utils.handle_auth_backend_error(error) else: - if not response or not response.get('kind'): - raise ApiCallError(USER_DELETE_ERROR, 'Failed to delete user: {0}.'.format(uid)) + if not body or not body.get('kind'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to delete user: {0}.'.format(uid), http_response=http_resp) def import_users(self, users, hash_alg=None): """Imports the given list of users to Firebase Auth.""" diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index f654eae42..19eaa54f5 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -44,6 +44,8 @@ 'ImportUserRecord', 'ListUsersPage', 'TokenSignError', + 'UidAlreadyExistsError', + 'UnexpectedResponseError', 'UserImportHash', 'UserImportResult', 'UserInfo', @@ -80,6 +82,7 @@ ImportUserRecord = _user_import.ImportUserRecord InvalidIdTokenError = _auth_utils.InvalidIdTokenError TokenSignError = _token_gen.TokenSignError +UidAlreadyExistsError = _auth_utils.UidAlreadyExistsError UnexpectedResponseError = _auth_utils.UnexpectedResponseError UserImportResult = _user_import.UserImportResult UserInfo = _user_mgt.UserInfo @@ -333,15 +336,12 @@ def create_user(**kwargs): Raises: ValueError: If the specified user properties are invalid. - AuthError: If an error occurs while creating the user account. + FirebaseError: If an error occurs while creating the user account. """ app = kwargs.pop('app', None) user_manager = _get_auth_service(app).user_manager - try: - uid = user_manager.create_user(**kwargs) - return UserRecord(user_manager.get_user(uid=uid)) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + uid = user_manager.create_user(**kwargs) + return UserRecord(user_manager.get_user(uid=uid)) def update_user(uid, **kwargs): @@ -373,15 +373,12 @@ def update_user(uid, **kwargs): Raises: ValueError: If the specified user ID or properties are invalid. - AuthError: If an error occurs while updating the user account. + FirebaseError: If an error occurs while updating the user account. """ app = kwargs.pop('app', None) user_manager = _get_auth_service(app).user_manager - try: - user_manager.update_user(uid, **kwargs) - return UserRecord(user_manager.get_user(uid=uid)) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + user_manager.update_user(uid, **kwargs) + return UserRecord(user_manager.get_user(uid=uid)) def set_custom_user_claims(uid, custom_claims, app=None): @@ -402,13 +399,10 @@ def set_custom_user_claims(uid, custom_claims, app=None): Raises: ValueError: If the specified user ID or the custom claims are invalid. - AuthError: If an error occurs while updating the user account. + FirebaseError: If an error occurs while updating the user account. """ user_manager = _get_auth_service(app).user_manager - try: - user_manager.update_user(uid, custom_claims=custom_claims) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + user_manager.update_user(uid, custom_claims=custom_claims) def delete_user(uid, app=None): @@ -420,13 +414,10 @@ def delete_user(uid, app=None): Raises: ValueError: If the user ID is None, empty or malformed. - AuthError: If an error occurs while deleting the user account. + FirebaseError: If an error occurs while deleting the user account. """ user_manager = _get_auth_service(app).user_manager - try: - user_manager.delete_user(uid) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + user_manager.delete_user(uid) def import_users(users, hash_alg=None, app=None): diff --git a/integration/test_auth.py b/integration/test_auth.py index 3af7be288..c0149fd69 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -148,14 +148,12 @@ def test_get_non_existing_user_by_email(): assert str(excinfo.value) == error_msg def test_update_non_existing_user(): - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UserNotFoundError): auth.update_user('non.existing') - assert 'USER_UPDATE_ERROR' in str(excinfo.value.code) def test_delete_non_existing_user(): - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UserNotFoundError): auth.delete_user('non.existing') - assert 'USER_DELETE_ERROR' in str(excinfo.value.code) @pytest.fixture def new_user(): @@ -258,9 +256,8 @@ def test_create_user(new_user): assert user.user_metadata.creation_timestamp > 0 assert user.user_metadata.last_sign_in_timestamp is None assert len(user.provider_data) is 0 - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UidAlreadyExistsError): auth.create_user(uid=new_user.uid) - assert excinfo.value.code == 'USER_CREATE_ERROR' def test_update_user(new_user): _, email = _random_id() @@ -329,9 +326,8 @@ def test_disable_user(new_user_with_params): def test_delete_user(): user = auth.create_user() auth.delete_user(user.uid) - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.UserNotFoundError): auth.get_user(user.uid) - assert excinfo.value.code == 'USER_NOT_FOUND_ERROR' def test_revoke_refresh_tokens(new_user): user = auth.get_user(new_user.uid) diff --git a/lint.sh b/lint.sh index 603b78f92..aeb37f741 100755 --- a/lint.sh +++ b/lint.sh @@ -20,7 +20,7 @@ function lintAllFiles () { } function lintChangedFiles () { - files=`git status -s $1 | grep -v "^D" | awk '{print $NF}' | grep .py$` + files=`git status -s $1 | (grep -v "^D") | awk '{print $NF}' | (grep .py$ || true)` for f in $files do echo "Running linter on $f" diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 951205621..f2ee4b58c 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -351,11 +351,31 @@ def test_create_user_with_id(self, user_mgt_app): assert request == {'localId' : 'testuser'} def test_create_user_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error": {"message": "UNEXPECTED_CODE"}}') + with pytest.raises(exceptions.InternalError) as excinfo: auth.create_user(app=user_mgt_app) - assert excinfo.value.code == _user_mgt.USER_CREATE_ERROR - assert '{"error":"test"}' in str(excinfo.value) + assert str(excinfo.value) == 'Error while calling Auth service (UNEXPECTED_CODE).' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def test_uid_already_exists(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 500, '{"error": {"message": "DUPLICATE_LOCAL_ID"}}') + with pytest.raises(auth.UidAlreadyExistsError) as excinfo: + auth.create_user(app=user_mgt_app) + assert isinstance(excinfo.value, exceptions.AlreadyExistsError) + assert str(excinfo.value) == ('The user with the provided uid already exists ' + '(DUPLICATE_LOCAL_ID).') + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def test_create_user_unexpected_response(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, '{"error": "test"}') + with pytest.raises(auth.UnexpectedResponseError) as excinfo: + auth.create_user(app=user_mgt_app) + assert str(excinfo.value) == 'Failed to create new user.' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is None + assert isinstance(excinfo.value, exceptions.UnknownError) class TestUpdateUser(object): @@ -462,11 +482,21 @@ def test_update_user_delete_fields(self, user_mgt_app): } def test_update_user_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error": {"message": "UNEXPECTED_CODE"}}') + with pytest.raises(exceptions.InternalError) as excinfo: auth.update_user('user', app=user_mgt_app) - assert excinfo.value.code == _user_mgt.USER_UPDATE_ERROR - assert '{"error":"test"}' in str(excinfo.value) + assert str(excinfo.value) == 'Error while calling Auth service (UNEXPECTED_CODE).' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def test_update_user_unexpected_response(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, '{"error": "test"}') + with pytest.raises(auth.UnexpectedResponseError) as excinfo: + auth.update_user('user', app=user_mgt_app) + assert str(excinfo.value) == 'Failed to update user: user.' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is None + assert isinstance(excinfo.value, exceptions.UnknownError) @pytest.mark.parametrize('arg', [1, 1.0]) def test_update_user_valid_since(self, user_mgt_app, arg): @@ -530,11 +560,12 @@ def test_set_custom_user_claims_none(self, user_mgt_app): assert request == {'localId' : 'testuser', 'customAttributes' : json.dumps({})} def test_set_custom_user_claims_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error": {"message": "UNEXPECTED_CODE"}}') + with pytest.raises(exceptions.InternalError) as excinfo: auth.set_custom_user_claims('user', {}, app=user_mgt_app) - assert excinfo.value.code == _user_mgt.USER_UPDATE_ERROR - assert '{"error":"test"}' in str(excinfo.value) + assert str(excinfo.value) == 'Error while calling Auth service (UNEXPECTED_CODE).' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None class TestDeleteUser(object): @@ -550,11 +581,21 @@ def test_delete_user(self, user_mgt_app): auth.delete_user('testuser', user_mgt_app) def test_delete_user_error(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + _instrument_user_manager(user_mgt_app, 500, '{"error": {"message": "UNEXPECTED_CODE"}}') + with pytest.raises(exceptions.InternalError) as excinfo: auth.delete_user('user', app=user_mgt_app) - assert excinfo.value.code == _user_mgt.USER_DELETE_ERROR - assert '{"error":"test"}' in str(excinfo.value) + assert str(excinfo.value) == 'Error while calling Auth service (UNEXPECTED_CODE).' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def test_delete_user_unexpected_response(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, '{"error": "test"}') + with pytest.raises(auth.UnexpectedResponseError) as excinfo: + auth.delete_user('user', app=user_mgt_app) + assert str(excinfo.value) == 'Failed to delete user: user.' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is None + assert isinstance(excinfo.value, exceptions.UnknownError) class TestListUsers(object): From 29c8b7ac4cf849973f267fa4574573b8977854f7 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 31 Jul 2019 10:48:51 -0700 Subject: [PATCH 08/37] Error handling improvements in email action link APIs (#312) * New error handling support in create/update/delete user APIs * Fixing some lint errors * Error handling update in email action link APIs --- firebase_admin/_auth_utils.py | 12 +++++++++++- firebase_admin/_user_mgt.py | 14 +++++++------- firebase_admin/auth.py | 30 ++++++++++++------------------ tests/test_user_mgt.py | 30 +++++++++++++++++++++++++++--- 4 files changed, 57 insertions(+), 29 deletions(-) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 2dfa23e08..08b930ae6 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -198,10 +198,19 @@ class UidAlreadyExistsError(exceptions.AlreadyExistsError): default_message = 'The user with the provided uid already exists' - def __init__(self, message, cause, http_response=None): + def __init__(self, message, cause, http_response): exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) +class InvalidDynamicLinkDomainError(exceptions.InvalidArgumentError): + """Dynamic link domain in ActionCodeSettings is not authorized.""" + + default_message = 'Dynamic link domain specified in ActionCodeSettings is not authorized' + + def __init__(self, message, cause, http_response): + exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) + + class InvalidIdTokenError(exceptions.InvalidArgumentError): """The provided ID token is not a valid Firebase ID token.""" @@ -229,6 +238,7 @@ def __init__(self, message, cause=None, http_response=None): _CODE_TO_EXC_TYPE = { 'DUPLICATE_LOCAL_ID': UidAlreadyExistsError, + 'INVALID_DYNAMIC_LINK_DOMAIN': InvalidDynamicLinkDomainError, 'INVALID_ID_TOKEN': InvalidIdTokenError, 'USER_NOT_FOUND': UserNotFoundError, } diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 61090da9b..3910f9690 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -26,7 +26,6 @@ USER_IMPORT_ERROR = 'USER_IMPORT_ERROR' USER_DOWNLOAD_ERROR = 'LIST_USERS_ERROR' -GENERATE_EMAIL_ACTION_LINK_ERROR = 'GENERATE_EMAIL_ACTION_LINK_ERROR' MAX_LIST_USERS_RESULTS = 1000 MAX_IMPORT_USERS_SIZE = 1000 @@ -654,14 +653,15 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No payload.update(encode_action_code_settings(action_code_settings)) try: - response = self._client.body('post', '/accounts:sendOobCode', json=payload) + body, http_resp = self._client.body_and_response( + 'post', '/accounts:sendOobCode', json=payload) except requests.exceptions.RequestException as error: - self._handle_http_error(GENERATE_EMAIL_ACTION_LINK_ERROR, 'Failed to generate link.', - error) + raise _auth_utils.handle_auth_backend_error(error) else: - if not response or not response.get('oobLink'): - raise ApiCallError(GENERATE_EMAIL_ACTION_LINK_ERROR, 'Failed to generate link.') - return response.get('oobLink') + if not body or not body.get('oobLink'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to generate email action link.', http_response=http_resp) + return body.get('oobLink') def _handle_http_error(self, code, msg, error): if error.response is not None: diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 19eaa54f5..61d71ad9f 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -42,6 +42,8 @@ 'ErrorInfo', 'ExportedUserRecord', 'ImportUserRecord', + 'InvalidDynamicLinkDomainError', + 'InvalidIdTokenError', 'ListUsersPage', 'TokenSignError', 'UidAlreadyExistsError', @@ -80,6 +82,7 @@ ListUsersPage = _user_mgt.ListUsersPage UserImportHash = _user_import.UserImportHash ImportUserRecord = _user_import.ImportUserRecord +InvalidDynamicLinkDomainError = _auth_utils.InvalidDynamicLinkDomainError InvalidIdTokenError = _auth_utils.InvalidIdTokenError TokenSignError = _token_gen.TokenSignError UidAlreadyExistsError = _auth_utils.UidAlreadyExistsError @@ -465,14 +468,11 @@ def generate_password_reset_link(email, action_code_settings=None, app=None): Raises: ValueError: If the provided arguments are invalid - AuthError: If an error occurs while generating the link + FirebaseError: If an error occurs while generating the link """ user_manager = _get_auth_service(app).user_manager - try: - return user_manager.generate_email_action_link('PASSWORD_RESET', email, - action_code_settings=action_code_settings) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + return user_manager.generate_email_action_link( + 'PASSWORD_RESET', email, action_code_settings=action_code_settings) def generate_email_verification_link(email, action_code_settings=None, app=None): @@ -490,14 +490,11 @@ def generate_email_verification_link(email, action_code_settings=None, app=None) Raises: ValueError: If the provided arguments are invalid - AuthError: If an error occurs while generating the link + FirebaseError: If an error occurs while generating the link """ user_manager = _get_auth_service(app).user_manager - try: - return user_manager.generate_email_action_link('VERIFY_EMAIL', email, - action_code_settings=action_code_settings) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + return user_manager.generate_email_action_link( + 'VERIFY_EMAIL', email, action_code_settings=action_code_settings) def generate_sign_in_with_email_link(email, action_code_settings, app=None): @@ -515,14 +512,11 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): Raises: ValueError: If the provided arguments are invalid - AuthError: If an error occurs while generating the link + FirebaseError: If an error occurs while generating the link """ user_manager = _get_auth_service(app).user_manager - try: - return user_manager.generate_email_action_link('EMAIL_SIGNIN', email, - action_code_settings=action_code_settings) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + return user_manager.generate_email_action_link( + 'EMAIL_SIGNIN', email, action_code_settings=action_code_settings) def _check_jwt_revoked(verified_claims, error_code, label, app): diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index f2ee4b58c..594de2d4c 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -1200,9 +1200,29 @@ def test_password_reset_with_settings(self, user_mgt_app): auth.generate_password_reset_link, ]) def test_api_call_failure(self, user_mgt_app, func): - _instrument_user_manager(user_mgt_app, 500, '{"error":"dummy error"}') - with pytest.raises(auth.AuthError): + _instrument_user_manager(user_mgt_app, 500, '{"error":{"message": "UNEXPECTED_CODE"}}') + with pytest.raises(exceptions.InternalError) as excinfo: + func('test@test.com', MOCK_ACTION_CODE_SETTINGS, app=user_mgt_app) + assert str(excinfo.value) == 'Error while calling Auth service (UNEXPECTED_CODE).' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + @pytest.mark.parametrize('func', [ + auth.generate_sign_in_with_email_link, + auth.generate_email_verification_link, + auth.generate_password_reset_link, + ]) + def test_invalid_dynamic_link(self, user_mgt_app, func): + resp = '{"error":{"message": "INVALID_DYNAMIC_LINK_DOMAIN: Because of this reason."}}' + _instrument_user_manager(user_mgt_app, 500, resp) + with pytest.raises(auth.InvalidDynamicLinkDomainError) as excinfo: func('test@test.com', MOCK_ACTION_CODE_SETTINGS, app=user_mgt_app) + assert isinstance(excinfo.value, exceptions.InvalidArgumentError) + assert str(excinfo.value) == ('Dynamic link domain specified in ActionCodeSettings is ' + 'not authorized (INVALID_DYNAMIC_LINK_DOMAIN). Because ' + 'of this reason.') + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None @pytest.mark.parametrize('func', [ auth.generate_sign_in_with_email_link, @@ -1211,8 +1231,12 @@ def test_api_call_failure(self, user_mgt_app, func): ]) def test_api_call_no_link(self, user_mgt_app, func): _instrument_user_manager(user_mgt_app, 200, '{}') - with pytest.raises(auth.AuthError): + with pytest.raises(auth.UnexpectedResponseError) as excinfo: func('test@test.com', MOCK_ACTION_CODE_SETTINGS, app=user_mgt_app) + assert str(excinfo.value) == 'Failed to generate email action link.' + assert excinfo.value.http_response is not None + assert excinfo.value.cause is None + assert isinstance(excinfo.value, exceptions.UnknownError) @pytest.mark.parametrize('func', [ auth.generate_sign_in_with_email_link, From 33614522a5b34af8eafe18dc6431adf6cccbe6b2 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 31 Jul 2019 11:15:58 -0700 Subject: [PATCH 09/37] Project management API migrated to new error types (#314) --- firebase_admin/project_management.py | 104 +++++----------- integration/test_project_management.py | 29 +++-- tests/test_project_management.py | 158 +++++++++++++++---------- 3 files changed, 142 insertions(+), 149 deletions(-) diff --git a/firebase_admin/project_management.py b/firebase_admin/project_management.py index cc57471c5..075ee7a68 100644 --- a/firebase_admin/project_management.py +++ b/firebase_admin/project_management.py @@ -25,6 +25,7 @@ import six import firebase_admin +from firebase_admin import exceptions from firebase_admin import _http_client from firebase_admin import _utils @@ -139,21 +140,6 @@ def _check_not_none(obj, field_name): return obj -class ApiCallError(Exception): - """An error encountered while interacting with the Firebase Project Management Service.""" - - def __init__(self, message, error): - Exception.__init__(self, message) - self.detail = error - - -class _PollingError(Exception): - """An error encountered during the polling of an app's creation status.""" - - def __init__(self, message): - Exception.__init__(self, message) - - class AndroidApp(object): """A reference to an Android app within a Firebase project. @@ -185,7 +171,7 @@ def get_metadata(self): AndroidAppMetadata: An ``AndroidAppMetadata`` instance. Raises: - ApiCallError: If an error occurs while communicating with the Firebase Project + FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ return self._service.get_android_app_metadata(self._app_id) @@ -200,7 +186,7 @@ def set_display_name(self, new_display_name): NoneType: None. Raises: - ApiCallError: If an error occurs while communicating with the Firebase Project + FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ return self._service.set_android_app_display_name(self._app_id, new_display_name) @@ -216,7 +202,7 @@ def get_sha_certificates(self): list: A list of ``ShaCertificate`` instances. Raises: - ApiCallError: If an error occurs while communicating with the Firebase Project + FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ return self._service.get_sha_certificates(self._app_id) @@ -231,7 +217,7 @@ def add_sha_certificate(self, certificate_to_add): NoneType: None. Raises: - ApiCallError: If an error occurs while communicating with the Firebase Project + FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. (For example, if the certificate_to_add already exists.) """ return self._service.add_sha_certificate(self._app_id, certificate_to_add) @@ -246,7 +232,7 @@ def delete_sha_certificate(self, certificate_to_delete): NoneType: None. Raises: - ApiCallError: If an error occurs while communicating with the Firebase Project + FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. (For example, if the certificate_to_delete is not found.) """ return self._service.delete_sha_certificate(certificate_to_delete) @@ -283,7 +269,7 @@ def get_metadata(self): IosAppMetadata: An ``IosAppMetadata`` instance. Raises: - ApiCallError: If an error occurs while communicating with the Firebase Project + FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ return self._service.get_ios_app_metadata(self._app_id) @@ -298,7 +284,7 @@ def set_display_name(self, new_display_name): NoneType: None. Raises: - ApiCallError: If an error occurs while communicating with the Firebase Project + FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ return self._service.set_ios_app_display_name(self._app_id, new_display_name) @@ -478,22 +464,11 @@ class _ProjectManagementService(object): MAXIMUM_POLLING_ATTEMPTS = 8 POLL_BASE_WAIT_TIME_SECONDS = 0.5 POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5 - ERROR_CODES = { - 401: 'Request not authorized.', - 403: 'Client does not have sufficient privileges.', - 404: 'Failed to find the resource.', - 409: 'The resource already exists.', - 429: 'Request throttled out by the backend server.', - 500: 'Internal server error.', - 503: 'Backend servers are over capacity. Try again later.' - } ANDROID_APPS_RESOURCE_NAME = 'androidApps' ANDROID_APP_IDENTIFIER_NAME = 'packageName' - ANDROID_APP_IDENTIFIER_LABEL = 'Package name' IOS_APPS_RESOURCE_NAME = 'iosApps' IOS_APP_IDENTIFIER_NAME = 'bundleId' - IOS_APP_IDENTIFIER_LABEL = 'Bundle ID' def __init__(self, app): project_id = app.project_id @@ -528,7 +503,7 @@ def _get_app_metadata(self, platform_resource_name, identifier_name, metadata_cl """Retrieves detailed information about an Android or iOS app.""" _check_is_nonempty_string(app_id, 'app_id') path = '/v1beta1/projects/-/{0}/{1}'.format(platform_resource_name, app_id) - response = self._make_request('get', path, app_id, 'App ID') + response = self._make_request('get', path) return metadata_class( response[identifier_name], name=response['name'], @@ -553,7 +528,7 @@ def _set_display_name(self, app_id, new_display_name, platform_resource_name): path = '/v1beta1/projects/-/{0}/{1}?updateMask=displayName'.format( platform_resource_name, app_id) request_body = {'displayName': new_display_name} - self._make_request('patch', path, app_id, 'App ID', json=request_body) + self._make_request('patch', path, json=request_body) def list_android_apps(self): return self._list_apps( @@ -571,7 +546,7 @@ def _list_apps(self, platform_resource_name, app_class): self._project_id, platform_resource_name, _ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE) - response = self._make_request('get', path, self._project_id, 'Project ID') + response = self._make_request('get', path) apps_list = [] while True: apps = response.get('apps') @@ -587,14 +562,13 @@ def _list_apps(self, platform_resource_name, app_class): platform_resource_name, next_page_token, _ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE) - response = self._make_request('get', path, self._project_id, 'Project ID') + response = self._make_request('get', path) return apps_list def create_android_app(self, package_name, display_name=None): return self._create_app( platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.ANDROID_APP_IDENTIFIER_NAME, - identifier_label=_ProjectManagementService.ANDROID_APP_IDENTIFIER_LABEL, identifier=package_name, display_name=display_name, app_class=AndroidApp) @@ -603,7 +577,6 @@ def create_ios_app(self, bundle_id, display_name=None): return self._create_app( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.IOS_APP_IDENTIFIER_NAME, - identifier_label=_ProjectManagementService.IOS_APP_IDENTIFIER_LABEL, identifier=bundle_id, display_name=display_name, app_class=IosApp) @@ -612,7 +585,6 @@ def _create_app( self, platform_resource_name, identifier_name, - identifier_label, identifier, display_name, app_class): @@ -622,15 +594,10 @@ def _create_app( request_body = {identifier_name: identifier} if display_name: request_body['displayName'] = display_name - response = self._make_request('post', path, identifier, identifier_label, json=request_body) + response = self._make_request('post', path, json=request_body) operation_name = response['name'] - try: - poll_response = self._poll_app_creation(operation_name) - return app_class(app_id=poll_response['appId'], service=self) - except _PollingError as error: - raise ApiCallError( - _ProjectManagementService._extract_message(operation_name, 'Operation name', error), - error) + poll_response = self._poll_app_creation(operation_name) + return app_class(app_id=poll_response['appId'], service=self) def _poll_app_creation(self, operation_name): """Polls the Long-Running Operation repeatedly until it is done with exponential backoff.""" @@ -640,16 +607,17 @@ def _poll_app_creation(self, operation_name): wait_time_seconds = delay_factor * _ProjectManagementService.POLL_BASE_WAIT_TIME_SECONDS time.sleep(wait_time_seconds) path = '/v1/{0}'.format(operation_name) - poll_response = self._make_request('get', path, operation_name, 'Operation name') + poll_response, http_response = self._body_and_response('get', path) done = poll_response.get('done') if done: response = poll_response.get('response') if response: return response else: - raise _PollingError( - 'Polling finished, but the operation terminated in an error.') - raise _PollingError('Polling deadline exceeded.') + raise exceptions.UnknownError( + 'Polling finished, but the operation terminated in an error.', + http_response=http_response) + raise exceptions.DeadlineExceededError('Polling deadline exceeded.') def get_android_app_config(self, app_id): return self._get_app_config( @@ -662,14 +630,14 @@ def get_ios_app_config(self, app_id): def _get_app_config(self, platform_resource_name, app_id): path = '/v1beta1/projects/-/{0}/{1}/config'.format(platform_resource_name, app_id) - response = self._make_request('get', path, app_id, 'App ID') + response = self._make_request('get', path) # In Python 2.7, the base64 module works with strings, while in Python 3, it works with # bytes objects. This line works in both versions. return base64.standard_b64decode(response['configFileContents']).decode(encoding='utf-8') def get_sha_certificates(self, app_id): path = '/v1beta1/projects/-/androidApps/{0}/sha'.format(app_id) - response = self._make_request('get', path, app_id, 'App ID') + response = self._make_request('get', path) cert_list = response.get('certificates') or [] return [ShaCertificate(sha_hash=cert['shaHash'], name=cert['name']) for cert in cert_list] @@ -678,28 +646,20 @@ def add_sha_certificate(self, app_id, certificate_to_add): sha_hash = _check_not_none(certificate_to_add, 'certificate_to_add').sha_hash cert_type = certificate_to_add.cert_type request_body = {'shaHash': sha_hash, 'certType': cert_type} - self._make_request('post', path, app_id, 'App ID', json=request_body) + self._make_request('post', path, json=request_body) def delete_sha_certificate(self, certificate_to_delete): name = _check_not_none(certificate_to_delete, 'certificate_to_delete').name path = '/v1beta1/{0}'.format(name) - self._make_request('delete', path, name, 'SHA ID') + self._make_request('delete', path) + + def _make_request(self, method, url, json=None): + body, _ = self._body_and_response(method, url, json) + return body - def _make_request(self, method, url, resource_identifier, resource_identifier_label, json=None): + def _body_and_response(self, method, url, json=None): try: - return self._client.body(method=method, url=url, json=json, timeout=self._timeout) + return self._client.body_and_response( + method=method, url=url, json=json, timeout=self._timeout) except requests.exceptions.RequestException as error: - raise ApiCallError( - _ProjectManagementService._extract_message( - resource_identifier, resource_identifier_label, error), - error) - - @staticmethod - def _extract_message(identifier, identifier_label, error): - if not isinstance(error, requests.exceptions.RequestException) or error.response is None: - return '{0} "{1}": {2}'.format(identifier_label, identifier, str(error)) - status = error.response.status_code - message = _ProjectManagementService.ERROR_CODES.get(status) - if message: - return '{0} "{1}": {2}'.format(identifier_label, identifier, message) - return '{0} "{1}": Error {2}.'.format(identifier_label, identifier, status) + raise _utils.handle_platform_error_from_requests(error) diff --git a/integration/test_project_management.py b/integration/test_project_management.py index 7386a4837..7aa182a42 100644 --- a/integration/test_project_management.py +++ b/integration/test_project_management.py @@ -20,6 +20,7 @@ import pytest +from firebase_admin import exceptions from firebase_admin import project_management @@ -64,11 +65,12 @@ def ios_app(default_app): def test_create_android_app_already_exists(android_app): del android_app - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: project_management.create_android_app( package_name=TEST_APP_PACKAGE_NAME, display_name=TEST_APP_DISPLAY_NAME_PREFIX) - assert 'The resource already exists' in str(excinfo.value) - assert excinfo.value.detail is not None + assert 'Requested entity already exists' in str(excinfo.value) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None def test_android_set_display_name_and_get_metadata(android_app, project_id): @@ -133,10 +135,11 @@ def test_android_sha_certificates(android_app): assert cert.name # Adding the same cert twice should cause an already-exists error. - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: android_app.add_sha_certificate(project_management.ShaCertificate(SHA_256_HASH_2)) - assert 'The resource already exists' in str(excinfo.value) - assert excinfo.value.detail is not None + assert 'Requested entity already exists' in str(excinfo.value) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None # Delete all certs and assert that they have all been deleted successfully. for cert in cert_list: @@ -145,20 +148,22 @@ def test_android_sha_certificates(android_app): assert android_app.get_sha_certificates() == [] # Deleting a nonexistent cert should cause a not-found error. - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: android_app.delete_sha_certificate(cert_list[0]) - assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert 'Requested entity was not found' in str(excinfo.value) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None def test_create_ios_app_already_exists(ios_app): del ios_app - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: project_management.create_ios_app( bundle_id=TEST_APP_BUNDLE_ID, display_name=TEST_APP_DISPLAY_NAME_PREFIX) - assert 'The resource already exists' in str(excinfo.value) - assert excinfo.value.detail is not None + assert 'Requested entity already exists' in str(excinfo.value) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None def test_ios_set_display_name_and_get_metadata(ios_app, project_id): diff --git a/tests/test_project_management.py b/tests/test_project_management.py index 9de95f7fd..b139a73c5 100644 --- a/tests/test_project_management.py +++ b/tests/test_project_management.py @@ -20,6 +20,7 @@ import pytest import firebase_admin +from firebase_admin import exceptions from firebase_admin import project_management from tests import testutils @@ -196,6 +197,10 @@ display_name='My iOS App', project_id='test-project-id') +ALREADY_EXISTS_RESPONSE = ('{"error": {"status": "ALREADY_EXISTS", ' + '"message": "The resource already exists"}}') +NOT_FOUND_RESPONSE = '{"error": {"message": "Failed to find the resource"}}' +UNAVAILABLE_RESPONSE = '{"error": {"message": "Backend servers are over capacity"}}' class TestAndroidAppMetadata(object): @@ -578,15 +583,16 @@ def test_create_android_app(self): recorder[2], 'GET', 'https://firebase.googleapis.com/v1/operations/abcdefg') def test_create_android_app_already_exists(self): - recorder = self._instrument_service(statuses=[409], responses=['some error response']) + recorder = self._instrument_service(statuses=[409], responses=[ALREADY_EXISTS_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: project_management.create_android_app( package_name='com.hello.world.android', display_name='My Android App') assert 'The resource already exists' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_create_android_app_polling_rpc_error(self): @@ -595,16 +601,17 @@ def test_create_android_app_polling_rpc_error(self): responses=[ OPERATION_IN_PROGRESS_RESPONSE, # Request to create Android app asynchronously. OPERATION_IN_PROGRESS_RESPONSE, # Creation operation is still not done. - 'some error response', # Error 503. + UNAVAILABLE_RESPONSE, # Error 503. ]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnavailableError) as excinfo: project_management.create_android_app( package_name='com.hello.world.android', display_name='My Android App') assert 'Backend servers are over capacity' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 3 def test_create_android_app_polling_failure(self): @@ -616,13 +623,14 @@ def test_create_android_app_polling_failure(self): OPERATION_FAILED_RESPONSE, # Operation is finished, but terminated with an error. ]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnknownError) as excinfo: project_management.create_android_app( package_name='com.hello.world.android', display_name='My Android App') assert 'Polling finished, but the operation terminated in an error' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is None + assert excinfo.value.http_response is not None assert len(recorder) == 3 def test_create_android_app_polling_limit_exceeded(self): @@ -635,13 +643,13 @@ def test_create_android_app_polling_limit_exceeded(self): OPERATION_IN_PROGRESS_RESPONSE, # Creation Operation is still not done. ]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.DeadlineExceededError) as excinfo: project_management.create_android_app( package_name='com.hello.world.android', display_name='My Android App') assert 'Polling deadline exceeded' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is None assert len(recorder) == 3 @@ -695,15 +703,16 @@ def test_create_ios_app(self): recorder[2], 'GET', 'https://firebase.googleapis.com/v1/operations/abcdefg') def test_create_ios_app_already_exists(self): - recorder = self._instrument_service(statuses=[409], responses=['some error response']) + recorder = self._instrument_service(statuses=[409], responses=[ALREADY_EXISTS_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: project_management.create_ios_app( bundle_id='com.hello.world.ios', display_name='My iOS App') assert 'The resource already exists' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_create_ios_app_polling_rpc_error(self): @@ -712,16 +721,17 @@ def test_create_ios_app_polling_rpc_error(self): responses=[ OPERATION_IN_PROGRESS_RESPONSE, # Request to create iOS app asynchronously. OPERATION_IN_PROGRESS_RESPONSE, # Creation operation is still not done. - 'some error response', # Error 503. + UNAVAILABLE_RESPONSE, # Error 503. ]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnavailableError) as excinfo: project_management.create_ios_app( bundle_id='com.hello.world.ios', display_name='My iOS App') assert 'Backend servers are over capacity' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 3 def test_create_ios_app_polling_failure(self): @@ -733,13 +743,14 @@ def test_create_ios_app_polling_failure(self): OPERATION_FAILED_RESPONSE, # Operation is finished, but terminated with an error. ]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnknownError) as excinfo: project_management.create_ios_app( bundle_id='com.hello.world.ios', display_name='My iOS App') assert 'Polling finished, but the operation terminated in an error' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is None + assert excinfo.value.http_response is not None assert len(recorder) == 3 def test_create_ios_app_polling_limit_exceeded(self): @@ -752,13 +763,13 @@ def test_create_ios_app_polling_limit_exceeded(self): OPERATION_IN_PROGRESS_RESPONSE, # Creation Operation is still not done. ]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.DeadlineExceededError) as excinfo: project_management.create_ios_app( bundle_id='com.hello.world.ios', display_name='My iOS App') assert 'Polling deadline exceeded' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is None assert len(recorder) == 3 @@ -779,13 +790,14 @@ def test_list_android_apps(self): self._assert_request_is_correct(recorder[0], 'GET', TestListAndroidApps._LISTING_URL) def test_list_android_apps_rpc_error(self): - recorder = self._instrument_service(statuses=[503], responses=['some error response']) + recorder = self._instrument_service(statuses=[503], responses=[UNAVAILABLE_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnavailableError) as excinfo: project_management.list_android_apps() assert 'Backend servers are over capacity' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_list_android_apps_empty_list(self): @@ -813,13 +825,14 @@ def test_list_android_apps_multiple_pages(self): def test_list_android_apps_multiple_pages_rpc_error(self): recorder = self._instrument_service( statuses=[200, 503], - responses=[LIST_ANDROID_APPS_PAGE_1_RESPONSE, 'some error response']) + responses=[LIST_ANDROID_APPS_PAGE_1_RESPONSE, UNAVAILABLE_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnavailableError) as excinfo: project_management.list_android_apps() assert 'Backend servers are over capacity' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 2 @@ -840,13 +853,14 @@ def test_list_ios_apps(self): self._assert_request_is_correct(recorder[0], 'GET', TestListIosApps._LISTING_URL) def test_list_ios_apps_rpc_error(self): - recorder = self._instrument_service(statuses=[503], responses=['some error response']) + recorder = self._instrument_service(statuses=[503], responses=[UNAVAILABLE_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnavailableError) as excinfo: project_management.list_ios_apps() assert 'Backend servers are over capacity' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_list_ios_apps_empty_list(self): @@ -874,13 +888,14 @@ def test_list_ios_apps_multiple_pages(self): def test_list_ios_apps_multiple_pages_rpc_error(self): recorder = self._instrument_service( statuses=[200, 503], - responses=[LIST_IOS_APPS_PAGE_1_RESPONSE, 'some error response']) + responses=[LIST_IOS_APPS_PAGE_1_RESPONSE, UNAVAILABLE_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnavailableError) as excinfo: project_management.list_ios_apps() assert 'Backend servers are over capacity' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 2 @@ -936,21 +951,24 @@ def test_get_metadata_unknown_error(self, android_app): recorder = self._instrument_service( statuses=[428], responses=['precondition required error']) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnknownError) as excinfo: android_app.get_metadata() - assert 'Error 428' in str(excinfo.value) - assert excinfo.value.detail is not None + message = 'Unexpected HTTP response with status: 428; body: precondition required error' + assert str(excinfo.value) == message + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_get_metadata_not_found(self, android_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: android_app.get_metadata() assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_set_display_name(self, android_app): @@ -965,14 +983,15 @@ def test_set_display_name(self, android_app): recorder[0], 'PATCH', TestAndroidApp._SET_DISPLAY_NAME_URL, body) def test_set_display_name_not_found(self, android_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) new_display_name = 'A new display name!' - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: android_app.set_display_name(new_display_name) assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_get_config(self, android_app): @@ -985,13 +1004,14 @@ def test_get_config(self, android_app): self._assert_request_is_correct(recorder[0], 'GET', TestAndroidApp._GET_CONFIG_URL) def test_get_config_not_found(self, android_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: android_app.get_config() assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_get_sha_certificates(self, android_app): @@ -1005,13 +1025,14 @@ def test_get_sha_certificates(self, android_app): self._assert_request_is_correct(recorder[0], 'GET', TestAndroidApp._LIST_CERTS_URL) def test_get_sha_certificates_not_found(self, android_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: android_app.get_sha_certificates() assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_add_certificate_none_error(self, android_app): @@ -1042,14 +1063,15 @@ def test_add_sha_256_certificate(self, android_app): self._assert_request_is_correct(recorder[0], 'POST', TestAndroidApp._ADD_CERT_URL, body) def test_add_sha_certificates_already_exists(self, android_app): - recorder = self._instrument_service(statuses=[409], responses=['some error response']) + recorder = self._instrument_service(statuses=[409], responses=[ALREADY_EXISTS_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: android_app.add_sha_certificate( project_management.ShaCertificate('123456789a123456789a123456789a123456789a')) assert 'The resource already exists' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_delete_certificate_none_error(self, android_app): @@ -1075,13 +1097,14 @@ def test_delete_sha_256_certificate(self, android_app): recorder[0], 'DELETE', TestAndroidApp._DELETE_SHA_256_CERT_URL) def test_delete_sha_certificates_not_found(self, android_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: android_app.delete_sha_certificate(SHA_1_CERTIFICATE) assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_raises_if_app_has_no_project_id(self): @@ -1137,21 +1160,24 @@ def test_get_metadata_unknown_error(self, ios_app): recorder = self._instrument_service( statuses=[428], responses=['precondition required error']) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnknownError) as excinfo: ios_app.get_metadata() - assert 'Error 428' in str(excinfo.value) - assert excinfo.value.detail is not None + message = 'Unexpected HTTP response with status: 428; body: precondition required error' + assert str(excinfo.value) == message + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_get_metadata_not_found(self, ios_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: ios_app.get_metadata() assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_set_display_name(self, ios_app): @@ -1166,14 +1192,15 @@ def test_set_display_name(self, ios_app): recorder[0], 'PATCH', TestIosApp._SET_DISPLAY_NAME_URL, body) def test_set_display_name_not_found(self, ios_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) new_display_name = 'A new display name!' - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: ios_app.set_display_name(new_display_name) assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_get_config(self, ios_app): @@ -1186,13 +1213,14 @@ def test_get_config(self, ios_app): self._assert_request_is_correct(recorder[0], 'GET', TestIosApp._GET_CONFIG_URL) def test_get_config_not_found(self, ios_app): - recorder = self._instrument_service(statuses=[404], responses=['some error response']) + recorder = self._instrument_service(statuses=[404], responses=[NOT_FOUND_RESPONSE]) - with pytest.raises(project_management.ApiCallError) as excinfo: + with pytest.raises(exceptions.NotFoundError) as excinfo: ios_app.get_config() assert 'Failed to find the resource' in str(excinfo.value) - assert excinfo.value.detail is not None + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None assert len(recorder) == 1 def test_raises_if_app_has_no_project_id(self): From dbb6970dfe263a56f16620735cace56edc3788a1 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 2 Aug 2019 15:20:59 -0700 Subject: [PATCH 10/37] Error handling updated for remaining user_mgt APIs (#315) * Error handling updated for remaining user_mgt APIs * Removed unused constants --- firebase_admin/_user_mgt.py | 35 +++++++++-------------------------- firebase_admin/auth.py | 16 +++++----------- tests/test_user_mgt.py | 24 +++++++++++++++++++++--- 3 files changed, 35 insertions(+), 40 deletions(-) diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 3910f9690..435594224 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -24,9 +24,6 @@ from firebase_admin import _user_import -USER_IMPORT_ERROR = 'USER_IMPORT_ERROR' -USER_DOWNLOAD_ERROR = 'LIST_USERS_ERROR' - MAX_LIST_USERS_RESULTS = 1000 MAX_IMPORT_USERS_SIZE = 1000 @@ -44,15 +41,6 @@ def __init__(self, description): DELETE_ATTRIBUTE = Sentinel('Value used to delete an attribute from a user profile') -class ApiCallError(Exception): - """Represents an Exception encountered while invoking the Firebase user management API.""" - - def __init__(self, code, message, error=None): - Exception.__init__(self, message) - self.code = code - self.detail = error - - class UserMetadata(object): """Contains additional metadata associated with a user account.""" @@ -510,7 +498,7 @@ def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): try: return self._client.body('get', '/accounts:batchGet', params=payload) except requests.exceptions.RequestException as error: - self._handle_http_error(USER_DOWNLOAD_ERROR, 'Failed to download user accounts.', error) + raise _auth_utils.handle_auth_backend_error(error) def create_user(self, uid=None, display_name=None, email=None, phone_number=None, photo_url=None, password=None, disabled=None, email_verified=None): @@ -619,13 +607,15 @@ def import_users(self, users, hash_alg=None): raise ValueError('A UserImportHash is required to import users with passwords.') payload.update(hash_alg.to_dict()) try: - response = self._client.body('post', '/accounts:batchCreate', json=payload) + body, http_resp = self._client.body_and_response( + 'post', '/accounts:batchCreate', json=payload) except requests.exceptions.RequestException as error: - self._handle_http_error(USER_IMPORT_ERROR, 'Failed to import users.', error) + raise _auth_utils.handle_auth_backend_error(error) else: - if not isinstance(response, dict): - raise ApiCallError(USER_IMPORT_ERROR, 'Failed to import users.') - return response + if not isinstance(body, dict): + raise _auth_utils.UnexpectedResponseError( + 'Failed to import users.', http_response=http_resp) + return body def generate_email_action_link(self, action_type, email, action_code_settings=None): """Fetches the email action links for types @@ -640,7 +630,7 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No link_url: action url to be emailed to the user Raises: - ApiCallError: If an error occurs while generating the link + FirebaseError: If an error occurs while generating the link ValueError: If the provided arguments are invalid """ payload = { @@ -663,13 +653,6 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No 'Failed to generate email action link.', http_response=http_resp) return body.get('oobLink') - def _handle_http_error(self, code, msg, error): - if error.response is not None: - msg += '\nServer response: {0}'.format(error.response.content.decode()) - else: - msg += '\nReason: {0}'.format(error) - raise ApiCallError(code, msg, error) - class _UserIterator(object): """An iterator that allows iterating over user accounts, one at a time. diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 61d71ad9f..0913e50f7 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -308,14 +308,11 @@ def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, ap Raises: ValueError: If max_results or page_token are invalid. - AuthError: If an error occurs while retrieving the user accounts. + FirebaseError: If an error occurs while retrieving the user accounts. """ user_manager = _get_auth_service(app).user_manager def download(page_token, max_results): - try: - return user_manager.list_users(page_token, max_results) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + return user_manager.list_users(page_token, max_results) return ListUsersPage(download, page_token, max_results) @@ -443,14 +440,11 @@ def import_users(users, hash_alg=None, app=None): Raises: ValueError: If the provided arguments are invalid. - AuthError: If an error occurs while importing users. + FirebaseError: If an error occurs while importing users. """ user_manager = _get_auth_service(app).user_manager - try: - result = user_manager.import_users(users, hash_alg) - return UserImportResult(result, len(users)) - except _user_mgt.ApiCallError as error: - raise AuthError(error.code, str(error), error.detail) + result = user_manager.import_users(users, hash_alg) + return UserImportResult(result, len(users)) def generate_password_reset_link(email, action_code_settings=None, app=None): diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 594de2d4c..34e2b019b 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -731,10 +731,9 @@ def test_list_users_with_all_args(self, user_mgt_app): def test_list_users_error(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 500, '{"error":"test"}') - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(exceptions.InternalError) as excinfo: auth.list_users(app=user_mgt_app) - assert excinfo.value.code == _user_mgt.USER_DOWNLOAD_ERROR - assert '{"error":"test"}' in str(excinfo.value) + assert str(excinfo.value) == 'Unexpected error response: {"error":"test"}' def _check_page(self, page): assert isinstance(page, auth.ListUsersPage) @@ -1076,6 +1075,25 @@ def test_import_users_with_hash(self, user_mgt_app): } self._check_rpc_calls(recorder, expected) + def test_import_users_http_error(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 401, '{"error": {"message": "ERROR_CODE"}}') + users = [ + auth.ImportUserRecord(uid='user1'), + auth.ImportUserRecord(uid='user2'), + ] + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: + auth.import_users(users, app=user_mgt_app) + assert str(excinfo.value) == 'Error while calling Auth service (ERROR_CODE).' + + def test_import_users_unexpected_response(self, user_mgt_app): + _instrument_user_manager(user_mgt_app, 200, '"not dict"') + users = [ + auth.ImportUserRecord(uid='user1'), + auth.ImportUserRecord(uid='user2'), + ] + with pytest.raises(auth.UnexpectedResponseError): + auth.import_users(users, app=user_mgt_app) + def _check_rpc_calls(self, recorder, expected): assert len(recorder) == 1 request = json.loads(recorder[0].body.decode()) From 12107239c89854c89966af4399ab9e5dfa69cc25 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 5 Aug 2019 09:52:31 -0700 Subject: [PATCH 11/37] Migrated token verification APIs to new exception types (#317) * Migrated token verification APIs to new error types * Removed old AuthError type * Added new exception types for revoked tokens --- firebase_admin/_auth_utils.py | 2 +- firebase_admin/_token_gen.py | 100 ++++++++++++++++++++++++++++------ firebase_admin/auth.py | 54 ++++++++++-------- integration/test_auth.py | 6 +- snippets/auth/index.py | 55 +++++++++---------- tests/test_token_gen.py | 89 ++++++++++++++++++++++-------- 6 files changed, 207 insertions(+), 99 deletions(-) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 08b930ae6..c7b6c15f1 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -216,7 +216,7 @@ class InvalidIdTokenError(exceptions.InvalidArgumentError): default_message = 'The provided ID token is invalid' - def __init__(self, message, cause, http_response=None): + def __init__(self, message, cause=None, http_response=None): exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 0ea34f77c..1d4556939 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -217,12 +217,18 @@ def __init__(self, app): project_id=app.project_id, short_name='ID token', operation='verify_id_token()', doc_url='https://firebase.google.com/docs/auth/admin/verify-id-tokens', - cert_url=ID_TOKEN_CERT_URI, issuer=ID_TOKEN_ISSUER_PREFIX) + cert_url=ID_TOKEN_CERT_URI, + issuer=ID_TOKEN_ISSUER_PREFIX, + invalid_token_error=_auth_utils.InvalidIdTokenError, + expired_token_error=ExpiredIdTokenError) self.cookie_verifier = _JWTVerifier( project_id=app.project_id, short_name='session cookie', operation='verify_session_cookie()', doc_url='https://firebase.google.com/docs/auth/admin/verify-id-tokens', - cert_url=COOKIE_CERT_URI, issuer=COOKIE_ISSUER_PREFIX) + cert_url=COOKIE_CERT_URI, + issuer=COOKIE_ISSUER_PREFIX, + invalid_token_error=InvalidSessionCookieError, + expired_token_error=ExpiredSessionCookieError) def verify_id_token(self, id_token): return self.id_token_verifier.verify(id_token, self.request) @@ -245,6 +251,8 @@ def __init__(self, **kwargs): self.articled_short_name = 'an {0}'.format(self.short_name) else: self.articled_short_name = 'a {0}'.format(self.short_name) + self._invalid_token_error = kwargs.pop('invalid_token_error') + self._expired_token_error = kwargs.pop('expired_token_error') def verify(self, token, request): """Verifies the signature and data for the provided JWT.""" @@ -261,8 +269,7 @@ def verify(self, token, request): 'or set your Firebase project ID as an app option. Alternatively set the ' 'GOOGLE_CLOUD_PROJECT environment variable.'.format(self.operation)) - header = jwt.decode_header(token) - payload = jwt.decode(token, verify=False) + header, payload = self._decode_unverified(token) issuer = payload.get('iss') audience = payload.get('aud') subject = payload.get('sub') @@ -275,12 +282,12 @@ def verify(self, token, request): 'See {0} for details on how to retrieve {1}.'.format(self.url, self.short_name)) error_message = None - if not header.get('kid'): - if audience == FIREBASE_AUDIENCE: - error_message = ( - '{0} expects {1}, but was given a custom ' - 'token.'.format(self.operation, self.articled_short_name)) - elif header.get('alg') == 'HS256' and payload.get( + if audience == FIREBASE_AUDIENCE: + error_message = ( + '{0} expects {1}, but was given a custom ' + 'token.'.format(self.operation, self.articled_short_name)) + elif not header.get('kid'): + if header.get('alg') == 'HS256' and payload.get( 'v') is 0 and 'uid' in payload.get('d', {}): error_message = ( '{0} expects {1}, but was given a legacy custom ' @@ -315,15 +322,30 @@ def verify(self, token, request): '{1}'.format(self.short_name, verify_id_token_msg)) if error_message: - raise ValueError(error_message) + raise self._invalid_token_error(error_message) + + try: + verified_claims = google.oauth2.id_token.verify_token( + token, + request=request, + audience=self.project_id, + certs_url=self.cert_url) + verified_claims['uid'] = verified_claims['sub'] + return verified_claims + except google.auth.exceptions.TransportError as error: + raise CertificateFetchError(str(error), cause=error) + except ValueError as error: + if 'Token expired' in str(error): + raise self._expired_token_error(str(error), cause=error) + raise self._invalid_token_error(str(error), cause=error) - verified_claims = google.oauth2.id_token.verify_token( - token, - request=request, - audience=self.project_id, - certs_url=self.cert_url) - verified_claims['uid'] = verified_claims['sub'] - return verified_claims + def _decode_unverified(self, token): + try: + header = jwt.decode_header(token) + payload = jwt.decode(token, verify=False) + return header, payload + except ValueError as error: + raise self._invalid_token_error(str(error), cause=error) class TokenSignError(exceptions.UnknownError): @@ -331,3 +353,45 @@ class TokenSignError(exceptions.UnknownError): def __init__(self, message, cause): exceptions.UnknownError.__init__(self, message, cause) + + +class CertificateFetchError(exceptions.UnknownError): + """Failed to fetch some public key certificates required to verify a token.""" + + def __init__(self, message, cause): + exceptions.UnknownError.__init__(self, message, cause) + + +class ExpiredIdTokenError(_auth_utils.InvalidIdTokenError): + """The provided ID token is expired.""" + + def __init__(self, message, cause): + _auth_utils.InvalidIdTokenError.__init__(self, message, cause) + + +class RevokedIdTokenError(_auth_utils.InvalidIdTokenError): + """The provided ID token has been revoked.""" + + def __init__(self, message): + _auth_utils.InvalidIdTokenError.__init__(self, message) + + +class InvalidSessionCookieError(exceptions.InvalidArgumentError): + """The provided string is not a valid Firebase session cookie.""" + + def __init__(self, message, cause=None): + exceptions.InvalidArgumentError.__init__(self, message, cause) + + +class ExpiredSessionCookieError(InvalidSessionCookieError): + """The provided session cookie is expired.""" + + def __init__(self, message, cause): + InvalidSessionCookieError.__init__(self, message, cause) + + +class RevokedSessionCookieError(InvalidSessionCookieError): + """The provided session cookie has been revoked.""" + + def __init__(self, message): + InvalidSessionCookieError.__init__(self, message) diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 0913e50f7..bbc7c613a 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -31,20 +31,23 @@ _AUTH_ATTRIBUTE = '_auth' -_ID_TOKEN_REVOKED = 'ID_TOKEN_REVOKED' -_SESSION_COOKIE_REVOKED = 'SESSION_COOKIE_REVOKED' __all__ = [ 'ActionCodeSettings', - 'AuthError', + 'CertificateFetchError', 'DELETE_ATTRIBUTE', 'ErrorInfo', + 'ExpiredIdTokenError', + 'ExpiredSessionCookieError', 'ExportedUserRecord', 'ImportUserRecord', 'InvalidDynamicLinkDomainError', 'InvalidIdTokenError', + 'InvalidSessionCookieError', 'ListUsersPage', + 'RevokedIdTokenError', + 'RevokedSessionCookieError', 'TokenSignError', 'UidAlreadyExistsError', 'UnexpectedResponseError', @@ -76,17 +79,23 @@ ] ActionCodeSettings = _user_mgt.ActionCodeSettings +CertificateFetchError = _token_gen.CertificateFetchError DELETE_ATTRIBUTE = _user_mgt.DELETE_ATTRIBUTE ErrorInfo = _user_import.ErrorInfo +ExpiredIdTokenError = _token_gen.ExpiredIdTokenError +ExpiredSessionCookieError = _token_gen.ExpiredSessionCookieError ExportedUserRecord = _user_mgt.ExportedUserRecord -ListUsersPage = _user_mgt.ListUsersPage -UserImportHash = _user_import.UserImportHash ImportUserRecord = _user_import.ImportUserRecord InvalidDynamicLinkDomainError = _auth_utils.InvalidDynamicLinkDomainError InvalidIdTokenError = _auth_utils.InvalidIdTokenError +InvalidSessionCookieError = _token_gen.InvalidSessionCookieError +ListUsersPage = _user_mgt.ListUsersPage +RevokedIdTokenError = _token_gen.RevokedIdTokenError +RevokedSessionCookieError = _token_gen.RevokedSessionCookieError TokenSignError = _token_gen.TokenSignError UidAlreadyExistsError = _auth_utils.UidAlreadyExistsError UnexpectedResponseError = _auth_utils.UnexpectedResponseError +UserImportHash = _user_import.UserImportHash UserImportResult = _user_import.UserImportResult UserInfo = _user_mgt.UserInfo UserMetadata = _user_mgt.UserMetadata @@ -149,9 +158,12 @@ def verify_id_token(id_token, app=None, check_revoked=False): dict: A dictionary of key-value pairs parsed from the decoded JWT. Raises: - ValueError: If the JWT was found to be invalid, or if the App's project ID cannot - be determined. - AuthError: If ``check_revoked`` is requested and the token was revoked. + ValueError: If ``id_token`` is a not a string or is empty. + InvalidIdTokenError: If ``id_token`` is not a valid Firebase ID token. + ExpiredIdTokenError: If the specified ID token has expired. + RevokedIdTokenError: If ``check_revoked`` is ``True`` and the ID token has been revoked. + CertificateFetchError: If an error occurs while fetching the public key certificates + required to verify the ID token. """ if not isinstance(check_revoked, bool): # guard against accidental wrong assignment. @@ -160,7 +172,7 @@ def verify_id_token(id_token, app=None, check_revoked=False): token_verifier = _get_auth_service(app).token_verifier verified_claims = token_verifier.verify_id_token(id_token) if check_revoked: - _check_jwt_revoked(verified_claims, _ID_TOKEN_REVOKED, 'ID token', app) + _check_jwt_revoked(verified_claims, RevokedIdTokenError, 'ID token', app) return verified_claims @@ -201,14 +213,17 @@ def verify_session_cookie(session_cookie, check_revoked=False, app=None): dict: A dictionary of key-value pairs parsed from the decoded JWT. Raises: - ValueError: If the cookie was found to be invalid, or if the App's project ID cannot - be determined. - AuthError: If ``check_revoked`` is requested and the cookie was revoked. + ValueError: If ``session_cookie`` is a not a string or is empty. + InvalidSessionCookieError: If ``session_cookie`` is not a valid Firebase session cookie. + ExpiredSessionCookieError: If the specified session cookie has expired. + RevokedSessionCookieError: If ``check_revoked`` is ``True`` and the cookie has been revoked. + CertificateFetchError: If an error occurs while fetching the public key certificates + required to verify the session cookie. """ token_verifier = _get_auth_service(app).token_verifier verified_claims = token_verifier.verify_session_cookie(session_cookie) if check_revoked: - _check_jwt_revoked(verified_claims, _SESSION_COOKIE_REVOKED, 'session cookie', app) + _check_jwt_revoked(verified_claims, RevokedSessionCookieError, 'session cookie', app) return verified_claims @@ -513,19 +528,10 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): 'EMAIL_SIGNIN', email, action_code_settings=action_code_settings) -def _check_jwt_revoked(verified_claims, error_code, label, app): +def _check_jwt_revoked(verified_claims, exc_type, label, app): user = get_user(verified_claims.get('uid'), app=app) if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: - raise AuthError(error_code, 'The Firebase {0} has been revoked.'.format(label)) - - -class AuthError(Exception): - """Represents an Exception encountered while invoking the Firebase auth API.""" - - def __init__(self, code, message, error=None): - Exception.__init__(self, message) - self.code = code - self.detail = error + raise exc_type('The Firebase {0} has been revoked.'.format(label)) class _AuthService(object): diff --git a/integration/test_auth.py b/integration/test_auth.py index c0149fd69..eb1464476 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -351,9 +351,8 @@ def test_verify_id_token_revoked(new_user, api_key): # verify_id_token succeeded because it didn't check revoked. assert claims['iat'] * 1000 < user.tokens_valid_after_timestamp - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.RevokedIdTokenError) as excinfo: claims = auth.verify_id_token(id_token, check_revoked=True) - assert excinfo.value.code == auth._ID_TOKEN_REVOKED assert str(excinfo.value) == 'The Firebase ID token has been revoked.' # Sign in again, verify works. @@ -373,9 +372,8 @@ def test_verify_session_cookie_revoked(new_user, api_key): # verify_session_cookie succeeded because it didn't check revoked. assert claims['iat'] * 1000 < user.tokens_valid_after_timestamp - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.RevokedSessionCookieError) as excinfo: claims = auth.verify_session_cookie(session_cookie, check_revoked=True) - assert excinfo.value.code == auth._SESSION_COOKIE_REVOKED assert str(excinfo.value) == 'The Firebase session cookie has been revoked.' # Sign in again, verify works. diff --git a/snippets/auth/index.py b/snippets/auth/index.py index 5bfe21f8e..552875696 100644 --- a/snippets/auth/index.py +++ b/snippets/auth/index.py @@ -24,6 +24,7 @@ # [END import_sdk] from firebase_admin import credentials from firebase_admin import auth +from firebase_admin import exceptions sys.path.append("lib") @@ -31,6 +32,7 @@ def initialize_sdk_with_service_account(): # [START initialize_sdk_with_service_account] import firebase_admin from firebase_admin import credentials + from firebase_admin import exceptions cred = credentials.Certificate('path/to/serviceAccountKey.json') default_app = firebase_admin.initialize_app(cred) @@ -144,13 +146,12 @@ def verify_token_uid_check_revoke(id_token): decoded_token = auth.verify_id_token(id_token, check_revoked=True) # Token is valid and not revoked. uid = decoded_token['uid'] - except auth.AuthError as exc: - if exc.code == 'ID_TOKEN_REVOKED': - # Token revoked, inform the user to reauthenticate or signOut(). - pass - else: - # Token is invalid - pass + except auth.RevokedIdTokenError: + # Token revoked, inform the user to reauthenticate or signOut(). + pass + except auth.InvalidIdTokenError: + # Token is invalid + pass # [END verify_token_id_check_revoked] firebase_admin.delete_app(default_app) return uid @@ -322,7 +323,7 @@ def session_login(): response.set_cookie( 'session', session_cookie, expires=expires, httponly=True, secure=True) return response - except auth.AuthError: + except exceptions.FirebaseError: return flask.abort(401, 'Failed to create a session cookie') # [END session_login] @@ -344,9 +345,9 @@ def check_auth_time(id_token, flask): # User did not sign in recently. To guard against ID token theft, require # re-authentication. return flask.abort(401, 'Recent sign in required') - except ValueError: + except auth.InvalidIdTokenError: return flask.abort(401, 'Invalid ID token') - except auth.AuthError: + except exceptions.FirebaseError: return flask.abort(401, 'Failed to create a session cookie') # [END check_auth_time] @@ -359,16 +360,17 @@ def serve_content_for_user(decoded_claims): @app.route('/profile', methods=['POST']) def access_restricted_content(): session_cookie = flask.request.cookies.get('session') + if not session_cookie: + # Session cookie is unavailable. Force user to login. + return flask.redirect('/login') + # Verify the session cookie. In this case an additional check is added to detect # if the user's Firebase session was revoked, user deleted/disabled, etc. try: decoded_claims = auth.verify_session_cookie(session_cookie, check_revoked=True) return serve_content_for_user(decoded_claims) - except ValueError: - # Session cookie is unavailable or invalid. Force user to login. - return flask.redirect('/login') - except auth.AuthError: - # Session revoked. Force user to login. + except auth.InvalidSessionCookieError: + # Session cookie is invalid, expired or revoked. Force user to login. return flask.redirect('/login') # [END session_verify] @@ -385,11 +387,8 @@ def serve_content_for_admin(decoded_claims): return serve_content_for_admin(decoded_claims) else: return flask.abort(401, 'Insufficient permissions') - except ValueError: - # Session cookie is unavailable or invalid. Force user to login. - return flask.redirect('/login') - except auth.AuthError: - # Session revoked. Force user to login. + except auth.InvalidSessionCookieError: + # Session cookie is invalid, expired or revoked. Force user to login. return flask.redirect('/login') # [END session_verify_with_permission_check] @@ -413,7 +412,7 @@ def session_logout(): response = flask.make_response(flask.redirect('/login')) response.set_cookie('session', expires=0) return response - except ValueError: + except auth.InvalidSessionCookieError: return flask.redirect('/login') # [END session_clear_and_revoke] @@ -444,7 +443,7 @@ def import_users(): result.success_count, result.failure_count)) for err in result.errors: print('Failed to import {0} due to {1}'.format(users[err.index].uid, err.reason)) - except auth.AuthError: + except exceptions.FirebaseError: # Some unrecoverable error occurred that prevented the operation from running. pass # [END import_users] @@ -465,7 +464,7 @@ def import_with_hmac(): result = auth.import_users(users, hash_alg=hash_alg) for err in result.errors: print('Failed to import user:', err.reason) - except auth.AuthError as error: + except exceptions.FirebaseError as error: print('Error importing users:', error) # [END import_with_hmac] @@ -485,7 +484,7 @@ def import_with_pbkdf(): result = auth.import_users(users, hash_alg=hash_alg) for err in result.errors: print('Failed to import user:', err.reason) - except auth.AuthError as error: + except exceptions.FirebaseError as error: print('Error importing users:', error) # [END import_with_pbkdf] @@ -506,7 +505,7 @@ def import_with_standard_scrypt(): result = auth.import_users(users, hash_alg=hash_alg) for err in result.errors: print('Failed to import user:', err.reason) - except auth.AuthError as error: + except exceptions.FirebaseError as error: print('Error importing users:', error) # [END import_with_standard_scrypt] @@ -526,7 +525,7 @@ def import_with_bcrypt(): result = auth.import_users(users, hash_alg=hash_alg) for err in result.errors: print('Failed to import user:', err.reason) - except auth.AuthError as error: + except exceptions.FirebaseError as error: print('Error importing users:', error) # [END import_with_bcrypt] @@ -553,7 +552,7 @@ def import_with_scrypt(): result = auth.import_users(users, hash_alg=hash_alg) for err in result.errors: print('Failed to import user:', err.reason) - except auth.AuthError as error: + except exceptions.FirebaseError as error: print('Error importing users:', error) # [END import_with_scrypt] @@ -583,7 +582,7 @@ def import_without_password(): result = auth.import_users(users) for err in result.errors: print('Failed to import user:', err.reason) - except auth.AuthError as error: + except exceptions.FirebaseError as error: print('Error importing users:', error) # [END import_without_password] diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index baf8d9515..e016b8fb1 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -46,6 +46,15 @@ INVALID_STRINGS = [None, '', 0, 1, True, False, list(), tuple(), dict()] INVALID_BOOLS = [None, '', 'foo', 0, 1, list(), tuple(), dict()] +INVALID_JWT_ARGS = { + 'NoneToken': None, + 'EmptyToken': '', + 'BoolToken': True, + 'IntToken': 1, + 'ListToken': [], + 'EmptyDictToken': {}, + 'NonEmptyDictToken': {'a': 1}, +} # Fixture for mocking a HTTP server httpserver = plugin.httpserver @@ -363,13 +372,6 @@ class TestVerifyIdToken(object): 'iat': int(time.time()) - 10000, 'exp': int(time.time()) - 3600 }), - 'NoneToken': None, - 'EmptyToken': '', - 'BoolToken': True, - 'IntToken': 1, - 'ListToken': [], - 'EmptyDictToken': {}, - 'NonEmptyDictToken': {'a': 1}, 'BadFormatToken': 'foobar' } @@ -392,9 +394,8 @@ def test_valid_token_check_revoked(self, user_mgt_app, id_token): def test_revoked_token_check_revoked(self, user_mgt_app, revoked_tokens, id_token): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) _instrument_user_manager(user_mgt_app, 200, revoked_tokens) - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.RevokedIdTokenError) as excinfo: auth.verify_id_token(id_token, app=user_mgt_app, check_revoked=True) - assert excinfo.value.code == 'ID_TOKEN_REVOKED' assert str(excinfo.value) == 'The Firebase ID token has been revoked.' @pytest.mark.parametrize('arg', INVALID_BOOLS) @@ -411,11 +412,30 @@ def test_revoked_token_do_not_check_revoked(self, user_mgt_app, revoked_tokens, assert claims['admin'] is True assert claims['uid'] == claims['sub'] + @pytest.mark.parametrize('id_token', INVALID_JWT_ARGS.values(), ids=list(INVALID_JWT_ARGS)) + def test_invalid_arg(self, user_mgt_app, id_token): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + with pytest.raises(ValueError) as excinfo: + auth.verify_id_token(id_token, app=user_mgt_app) + assert 'Illegal ID token provided' in str(excinfo.value) + @pytest.mark.parametrize('id_token', invalid_tokens.values(), ids=list(invalid_tokens)) def test_invalid_token(self, user_mgt_app, id_token): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) - with pytest.raises(ValueError): + with pytest.raises(auth.InvalidIdTokenError) as excinfo: auth.verify_id_token(id_token, app=user_mgt_app) + assert isinstance(excinfo.value, exceptions.InvalidArgumentError) + assert excinfo.value.http_response is None + + def test_expired_token(self, user_mgt_app): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + id_token = self.invalid_tokens['ExpiredToken'] + with pytest.raises(auth.ExpiredIdTokenError) as excinfo: + auth.verify_id_token(id_token, app=user_mgt_app) + assert isinstance(excinfo.value, auth.InvalidIdTokenError) + assert 'Token expired' in str(excinfo.value) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is None def test_project_id_option(self): app = firebase_admin.initialize_app( @@ -440,13 +460,19 @@ def test_project_id_env_var(self, env_var_app): def test_custom_token(self, auth_app): id_token = auth.create_custom_token(MOCK_UID, app=auth_app) _overwrite_cert_request(auth_app, MOCK_REQUEST) - with pytest.raises(ValueError): + with pytest.raises(auth.InvalidIdTokenError) as excinfo: auth.verify_id_token(id_token, app=auth_app) + message = 'verify_id_token() expects an ID token, but was given a custom token.' + assert str(excinfo.value) == message def test_certificate_request_failure(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, testutils.MockRequest(404, 'not found')) - with pytest.raises(google.auth.exceptions.TransportError): + with pytest.raises(auth.CertificateFetchError) as excinfo: auth.verify_id_token(TEST_ID_TOKEN, app=user_mgt_app) + assert 'Could not fetch certificates' in str(excinfo.value) + assert isinstance(excinfo.value, exceptions.UnknownError) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is None class TestVerifySessionCookie(object): @@ -471,13 +497,6 @@ class TestVerifySessionCookie(object): 'iat': int(time.time()) - 10000, 'exp': int(time.time()) - 3600 }), - 'NoneCookie': None, - 'EmptyCookie': '', - 'BoolCookie': True, - 'IntCookie': 1, - 'ListCookie': [], - 'EmptyDictCookie': {}, - 'NonEmptyDictCookie': {'a': 1}, 'BadFormatCookie': 'foobar', 'IDToken': TEST_ID_TOKEN, } @@ -501,9 +520,8 @@ def test_valid_cookie_check_revoked(self, user_mgt_app, cookie): def test_revoked_cookie_check_revoked(self, user_mgt_app, revoked_tokens, cookie): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) _instrument_user_manager(user_mgt_app, 200, revoked_tokens) - with pytest.raises(auth.AuthError) as excinfo: + with pytest.raises(auth.RevokedSessionCookieError) as excinfo: auth.verify_session_cookie(cookie, app=user_mgt_app, check_revoked=True) - assert excinfo.value.code == 'SESSION_COOKIE_REVOKED' assert str(excinfo.value) == 'The Firebase session cookie has been revoked.' @pytest.mark.parametrize('cookie', valid_cookies.values(), ids=list(valid_cookies)) @@ -514,11 +532,30 @@ def test_revoked_cookie_does_not_check_revoked(self, user_mgt_app, revoked_token assert claims['admin'] is True assert claims['uid'] == claims['sub'] + @pytest.mark.parametrize('cookie', INVALID_JWT_ARGS.values(), ids=list(INVALID_JWT_ARGS)) + def test_invalid_args(self, user_mgt_app, cookie): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + with pytest.raises(ValueError) as excinfo: + auth.verify_session_cookie(cookie, app=user_mgt_app) + assert 'Illegal session cookie provided' in str(excinfo.value) + @pytest.mark.parametrize('cookie', invalid_cookies.values(), ids=list(invalid_cookies)) def test_invalid_cookie(self, user_mgt_app, cookie): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) - with pytest.raises(ValueError): + with pytest.raises(auth.InvalidSessionCookieError) as excinfo: auth.verify_session_cookie(cookie, app=user_mgt_app) + assert isinstance(excinfo.value, exceptions.InvalidArgumentError) + assert excinfo.value.http_response is None + + def test_expired_cookie(self, user_mgt_app): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + cookie = self.invalid_cookies['ExpiredCookie'] + with pytest.raises(auth.ExpiredSessionCookieError) as excinfo: + auth.verify_session_cookie(cookie, app=user_mgt_app) + assert isinstance(excinfo.value, auth.InvalidSessionCookieError) + assert 'Token expired' in str(excinfo.value) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is None def test_project_id_option(self): app = firebase_admin.initialize_app( @@ -540,13 +577,17 @@ def test_project_id_env_var(self, env_var_app): def test_custom_token(self, auth_app): custom_token = auth.create_custom_token(MOCK_UID, app=auth_app) _overwrite_cert_request(auth_app, MOCK_REQUEST) - with pytest.raises(ValueError): + with pytest.raises(auth.InvalidSessionCookieError): auth.verify_session_cookie(custom_token, app=auth_app) def test_certificate_request_failure(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, testutils.MockRequest(404, 'not found')) - with pytest.raises(google.auth.exceptions.TransportError): + with pytest.raises(auth.CertificateFetchError) as excinfo: auth.verify_session_cookie(TEST_SESSION_COOKIE, app=user_mgt_app) + assert 'Could not fetch certificates' in str(excinfo.value) + assert isinstance(excinfo.value, exceptions.UnknownError) + assert excinfo.value.cause is not None + assert excinfo.value.http_response is None class TestCertificateCaching(object): From 299e80803c94d53c7957674cf1dd6cf4cf35b7d0 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Mon, 5 Aug 2019 14:02:17 -0700 Subject: [PATCH 12/37] Migrated the db module to the new exception types (#318) * Migrating db module to new exception types * Error handling for transactions * Updated integration tests * Restoring the old txn abort behavior * Updated error type in snippet * Added comment --- firebase_admin/_utils.py | 1 + firebase_admin/db.py | 95 ++++++++++++++++++------------------- integration/test_db.py | 33 ++++++------- snippets/database/index.py | 2 +- tests/test_db.py | 97 +++++++++++++++++++++++++++++++------- 5 files changed, 144 insertions(+), 84 deletions(-) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 42b83809e..95ed2c414 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -52,6 +52,7 @@ 403: exceptions.PERMISSION_DENIED, 404: exceptions.NOT_FOUND, 409: exceptions.CONFLICT, + 412: exceptions.FAILED_PRECONDITION, 429: exceptions.RESOURCE_EXHAUSTED, 500: exceptions.INTERNAL, 503: exceptions.UNAVAILABLE, diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 53efd9b15..ef7c96721 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -32,6 +32,7 @@ from six.moves import urllib import firebase_admin +from firebase_admin import exceptions from firebase_admin import _http_client from firebase_admin import _sseclient from firebase_admin import _utils @@ -209,7 +210,7 @@ def get(self, etag=False, shallow=False): Raises: ValueError: If both ``etag`` and ``shallow`` are set to True. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if etag: if shallow: @@ -236,7 +237,7 @@ def get_if_changed(self, etag): Raises: ValueError: If the ETag is not a string. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if not isinstance(etag, six.string_types): raise ValueError('ETag must be a string.') @@ -258,7 +259,7 @@ def set(self, value): Raises: ValueError: If the provided value is None. TypeError: If the value is not JSON-serializable. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if value is None: raise ValueError('Value must not be None.') @@ -281,7 +282,7 @@ def set_if_unchanged(self, expected_etag, value): Raises: ValueError: If the value is None, or if expected_etag is not a string. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ # pylint: disable=missing-raises-doc if not isinstance(expected_etag, six.string_types): @@ -293,11 +294,11 @@ def set_if_unchanged(self, expected_etag, value): headers = self._client.headers( 'put', self._add_suffix(), json=value, headers={'if-match': expected_etag}) return True, value, headers.get('ETag') - except ApiCallError as error: - detail = error.detail - if detail.response is not None and 'ETag' in detail.response.headers: - etag = detail.response.headers['ETag'] - snapshot = detail.response.json() + except exceptions.FailedPreconditionError as error: + http_response = error.http_response + if http_response is not None and 'ETag' in http_response.headers: + etag = http_response.headers['ETag'] + snapshot = http_response.json() return False, snapshot, etag else: raise error @@ -317,7 +318,7 @@ def push(self, value=''): Raises: ValueError: If the value is None. TypeError: If the value is not JSON-serializable. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if value is None: raise ValueError('Value must not be None.') @@ -333,7 +334,7 @@ def update(self, value): Raises: ValueError: If value is empty or not a dictionary. - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ if not value or not isinstance(value, dict): raise ValueError('Value argument must be a non-empty dictionary.') @@ -345,7 +346,7 @@ def delete(self): """Deletes this node from the database. Raises: - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ self._client.request('delete', self._add_suffix()) @@ -371,7 +372,7 @@ def listen(self, callback): ListenerRegistration: An object that can be used to stop the event listener. Raises: - ApiCallError: If an error occurs while starting the initial HTTP connection. + FirebaseError: If an error occurs while starting the initial HTTP connection. """ session = _sseclient.KeepAuthSession(self._client.credential) return self._listen_with_session(callback, session) @@ -387,9 +388,9 @@ def transaction(self, transaction_update): value of this reference into a new value. If another client writes to this location before the new value is successfully saved, the update function is called again with the new current value, and the write will be retried. In case of repeated failures, this method - will retry the transaction up to 25 times before giving up and raising a TransactionError. - The update function may also force an early abort by raising an exception instead of - returning a value. + will retry the transaction up to 25 times before giving up and raising a + TransactionAbortedError. The update function may also force an early abort by raising an + exception instead of returning a value. Args: transaction_update: A function which will be passed the current data stored at this @@ -402,7 +403,7 @@ def transaction(self, transaction_update): object: New value of the current database Reference (only if the transaction commits). Raises: - TransactionError: If the transaction aborts after exhausting all retry attempts. + TransactionAbortedError: If the transaction aborts after exhausting all retry attempts. ValueError: If transaction_update is not a function. """ if not callable(transaction_update): @@ -416,7 +417,8 @@ def transaction(self, transaction_update): if success: return new_data tries += 1 - raise TransactionError('Transaction aborted after failed retries.') + + raise TransactionAbortedError('Transaction aborted after failed retries.') def order_by_child(self, path): """Returns a Query that orders data by child values. @@ -468,7 +470,7 @@ def _listen_with_session(self, callback, session): sse = _sseclient.SSEClient(url, session) return ListenerRegistration(callback, sse) except requests.exceptions.RequestException as error: - raise ApiCallError(_Client.extract_error_message(error), error) + raise _Client.handle_rtdb_error(error) class Query(object): @@ -614,7 +616,7 @@ def get(self): object: Decoded JSON result of the Query. Raises: - ApiCallError: If an error occurs while communicating with the remote database server. + FirebaseError: If an error occurs while communicating with the remote database server. """ result = self._client.body('get', self._pathurl, params=self._querystr) if isinstance(result, (dict, list)) and self._order_by != '$priority': @@ -622,20 +624,11 @@ def get(self): return result -class ApiCallError(Exception): - """Represents an Exception encountered while invoking the Firebase database server API.""" - - def __init__(self, message, error): - Exception.__init__(self, message) - self.detail = error - - -class TransactionError(Exception): - """Represents an Exception encountered while performing a transaction.""" +class TransactionAbortedError(exceptions.AbortedError): + """A transaction was aborted aftr exceeding the maximum number of retries.""" def __init__(self, message): - Exception.__init__(self, message) - + exceptions.AbortedError.__init__(self, message) class _Sorter(object): @@ -934,7 +927,7 @@ def request(self, method, url, **kwargs): Response: An HTTP response object. Raises: - ApiCallError: If an error occurs while making the HTTP call. + FirebaseError: If an error occurs while making the HTTP call. """ query = '&'.join('{0}={1}'.format(key, self.params[key]) for key in self.params) extra_params = kwargs.get('params') @@ -950,33 +943,39 @@ def request(self, method, url, **kwargs): try: return super(_Client, self).request(method, url, **kwargs) except requests.exceptions.RequestException as error: - raise ApiCallError(_Client.extract_error_message(error), error) + raise _Client.handle_rtdb_error(error) + + @classmethod + def handle_rtdb_error(cls, error): + """Converts an error encountered while calling RTDB into a FirebaseError.""" + if error.response is None: + return _utils.handle_requests_error(error) + + message = cls._extract_error_message(error.response) + return _utils.handle_requests_error(error, message=message) @classmethod - def extract_error_message(cls, error): - """Extracts an error message from an exception. + def _extract_error_message(cls, response): + """Extracts an error message from an error response. - If the server has not sent any response, simply converts the exception into a string. If the server has sent a JSON response with an 'error' field, which is the typical behavior of the Realtime Database REST API, parses the response to retrieve the error message. If the server has sent a non-JSON response, returns the full response as the error message. - - Args: - error: An exception raised by the requests library. - - Returns: - str: A string error message extracted from the exception. """ - if error.response is None: - return str(error) + message = None try: - data = error.response.json() + # RTDB error format: {"error": "text message"} + data = response.json() if isinstance(data, dict): - return '{0}\nReason: {1}'.format(error, data.get('error', 'unknown')) + message = data.get('error') except ValueError: pass - return '{0}\nReason: {1}'.format(error, error.response.content.decode()) + + if not message: + message = 'Unexpected response from database: {0}'.format(response.content.decode()) + + return message class _EmulatorAdminCredentials(google.auth.credentials.Credentials): diff --git a/integration/test_db.py b/integration/test_db.py index d88d145ba..4c2f6bde2 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -22,6 +22,7 @@ import firebase_admin from firebase_admin import db +from firebase_admin import exceptions from integration import conftest from tests import testutils @@ -359,30 +360,26 @@ def init_ref(self, path, app): admin_ref.set('test') assert admin_ref.get() == 'test' - def check_permission_error(self, excinfo): - assert isinstance(excinfo.value, db.ApiCallError) - assert 'Reason: Permission denied' in str(excinfo.value) - def test_no_access(self, app, override_app): path = '_adminsdk/python/admin' self.init_ref(path, app) user_ref = db.reference(path, override_app) - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert user_ref.get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: user_ref.set('test2') - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' def test_read(self, app, override_app): path = '_adminsdk/python/protected/user2' self.init_ref(path, app) user_ref = db.reference(path, override_app) assert user_ref.get() == 'test' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: user_ref.set('test2') - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' def test_read_write(self, app, override_app): path = '_adminsdk/python/protected/user1' @@ -394,9 +391,9 @@ def test_read_write(self, app, override_app): def test_query(self, override_app): user_ref = db.reference('_adminsdk/python/protected', override_app) - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: user_ref.order_by_key().limit_to_first(2).get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' def test_none_auth_override(self, app, none_override_app): path = '_adminsdk/python/public' @@ -405,14 +402,14 @@ def test_none_auth_override(self, app, none_override_app): assert public_ref.get() == 'test' ref = db.reference('_adminsdk/python', none_override_app) - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert ref.child('protected/user1').get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert ref.child('protected/user2').get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.UnauthenticatedError) as excinfo: assert ref.child('admin').get() - self.check_permission_error(excinfo) + assert str(excinfo.value) == 'Permission denied' diff --git a/snippets/database/index.py b/snippets/database/index.py index fee23f626..adfa13476 100644 --- a/snippets/database/index.py +++ b/snippets/database/index.py @@ -214,7 +214,7 @@ def increment_votes(current_value): try: new_vote_count = upvotes_ref.transaction(increment_votes) print('Transaction completed') - except db.TransactionError: + except db.TransactionAbortedError: print('Transaction failed to commit') # [END transaction] diff --git a/tests/test_db.py b/tests/test_db.py index 211eabb4b..081c31e3d 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -22,6 +22,7 @@ import firebase_admin from firebase_admin import db +from firebase_admin import exceptions from firebase_admin import _sseclient from tests import testutils @@ -31,14 +32,15 @@ class MockAdapter(testutils.MockAdapter): ETAG = '0' - def __init__(self, data, status, recorder): + def __init__(self, data, status, recorder, etag=ETAG): testutils.MockAdapter.__init__(self, data, status, recorder) + self._etag = etag def send(self, request, **kwargs): if_match = request.headers.get('if-match') if_none_match = request.headers.get('if-none-match') resp = super(MockAdapter, self).send(request, **kwargs) - resp.headers = {'ETag': MockAdapter.ETAG} + resp.headers = {'ETag': self._etag} if if_match and if_match != MockAdapter.ETAG: resp.status_code = 412 elif if_none_match == MockAdapter.ETAG: @@ -125,6 +127,38 @@ def test_invalid_child(self, child): parent.child(child) +class _RefOperations(object): + """A collection of operations that can be performed using a ``db.Reference``. + + This can be used to test any functionality that is common across multiple API calls. + """ + + @classmethod + def get(cls, ref): + ref.get() + + @classmethod + def push(cls, ref): + ref.push() + + @classmethod + def set(cls, ref): + ref.set({'foo': 'bar'}) + + @classmethod + def delete(cls, ref): + ref.delete() + + @classmethod + def query(cls, ref): + query = ref.order_by_key() + query.get() + + @classmethod + def get_ops(cls): + return [cls.get, cls.push, cls.set, cls.delete, cls.query] + + class TestReference(object): """Test cases for database queries via References.""" @@ -132,6 +166,12 @@ class TestReference(object): valid_values = [ '', 'foo', 0, 1, 100, 1.2, True, False, [], [1, 2], {}, {'foo' : 'bar'} ] + error_codes = { + 400: exceptions.InvalidArgumentError, + 401: exceptions.UnauthenticatedError, + 404: exceptions.NotFoundError, + 500: exceptions.InternalError, + } @classmethod def setup_class(cls): @@ -141,9 +181,9 @@ def setup_class(cls): def teardown_class(cls): testutils.cleanup_apps() - def instrument(self, ref, payload, status=200): + def instrument(self, ref, payload, status=200, etag=MockAdapter.ETAG): recorder = [] - adapter = MockAdapter(payload, status, recorder) + adapter = MockAdapter(payload, status, recorder, etag) ref._client.session.mount(self.test_url, adapter) return recorder @@ -427,6 +467,19 @@ def transaction_update(data): assert len(recorder) == 1 assert recorder[0].method == 'GET' + def test_transaction_abort(self): + ref = db.reference('/test/count') + data = 42 + recorder = self.instrument(ref, json.dumps(data), etag='1') + + with pytest.raises(db.TransactionAbortedError) as excinfo: + ref.transaction(lambda x: x + 1 if x else 1) + assert isinstance(excinfo.value, exceptions.AbortedError) + assert str(excinfo.value) == 'Transaction aborted after failed retries.' + assert excinfo.value.cause is None + assert excinfo.value.http_response is None + assert len(recorder) == 1 + 25 + @pytest.mark.parametrize('func', [None, 0, 1, True, False, 'foo', dict(), list(), tuple()]) def test_transaction_invalid_function(self, func): ref = db.reference('/test') @@ -449,21 +502,29 @@ def test_get_reference(self, path, expected): else: assert ref.parent.path == parent - @pytest.mark.parametrize('error_code', [400, 401, 500]) - def test_server_error(self, error_code): + @pytest.mark.parametrize('error_code', error_codes.keys()) + @pytest.mark.parametrize('func', _RefOperations.get_ops()) + def test_server_error(self, error_code, func): ref = db.reference('/test') self.instrument(ref, json.dumps({'error' : 'json error message'}), error_code) - with pytest.raises(db.ApiCallError) as excinfo: - ref.get() - assert 'Reason: json error message' in str(excinfo.value) - - @pytest.mark.parametrize('error_code', [400, 401, 500]) - def test_other_error(self, error_code): + exc_type = self.error_codes[error_code] + with pytest.raises(exc_type) as excinfo: + func(ref) + assert str(excinfo.value) == 'json error message' + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None + + @pytest.mark.parametrize('error_code', error_codes.keys()) + @pytest.mark.parametrize('func', _RefOperations.get_ops()) + def test_other_error(self, error_code, func): ref = db.reference('/test') self.instrument(ref, 'custom error message', error_code) - with pytest.raises(db.ApiCallError) as excinfo: - ref.get() - assert 'Reason: custom error message' in str(excinfo.value) + exc_type = self.error_codes[error_code] + with pytest.raises(exc_type) as excinfo: + func(ref) + assert str(excinfo.value) == 'Unexpected response from database: custom error message' + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None class TestListenerRegistration(object): @@ -481,9 +542,11 @@ def test_listen_error(self): session.mount(test_url, adapter) def callback(_): pass - with pytest.raises(db.ApiCallError) as excinfo: + with pytest.raises(exceptions.InternalError) as excinfo: ref._listen_with_session(callback, session) - assert 'Reason: json error message' in str(excinfo.value) + assert str(excinfo.value) == 'json error message' + assert excinfo.value.cause is not None + assert excinfo.value.http_response is not None finally: testutils.cleanup_apps() From 030f6e66658f4816abf2c9f41ad58c97114665e3 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 8 Aug 2019 15:27:13 -0700 Subject: [PATCH 13/37] Adding a few overlooked error types (#319) * Adding some missing error types * Updated documentation --- firebase_admin/_auth_utils.py | 20 ++++++++++++++++++++ firebase_admin/auth.py | 4 ++++ firebase_admin/exceptions.py | 13 ++++++++++++- tests/test_user_mgt.py | 18 +++++++++++++----- 4 files changed, 49 insertions(+), 6 deletions(-) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index c7b6c15f1..d90b494f5 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -202,6 +202,15 @@ def __init__(self, message, cause, http_response): exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) +class EmailAlreadyExistsError(exceptions.AlreadyExistsError): + """The user with the provided email already exists.""" + + default_message = 'The user with the provided email already exists' + + def __init__(self, message, cause, http_response): + exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) + + class InvalidDynamicLinkDomainError(exceptions.InvalidArgumentError): """Dynamic link domain in ActionCodeSettings is not authorized.""" @@ -220,6 +229,15 @@ def __init__(self, message, cause=None, http_response=None): exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) +class PhoneNumberAlreadyExistsError(exceptions.AlreadyExistsError): + """The user with the provided phone number already exists.""" + + default_message = 'The user with the provided phone number already exists' + + def __init__(self, message, cause, http_response): + exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) + + class UnexpectedResponseError(exceptions.UnknownError): """Backend service responded with an unexpected or malformed response.""" @@ -237,9 +255,11 @@ def __init__(self, message, cause=None, http_response=None): _CODE_TO_EXC_TYPE = { + 'DUPLICATE_EMAIL': EmailAlreadyExistsError, 'DUPLICATE_LOCAL_ID': UidAlreadyExistsError, 'INVALID_DYNAMIC_LINK_DOMAIN': InvalidDynamicLinkDomainError, 'INVALID_ID_TOKEN': InvalidIdTokenError, + 'PHONE_NUMBER_EXISTS': PhoneNumberAlreadyExistsError, 'USER_NOT_FOUND': UserNotFoundError, } diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index bbc7c613a..cddc8ab0d 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -37,6 +37,7 @@ 'ActionCodeSettings', 'CertificateFetchError', 'DELETE_ATTRIBUTE', + 'EmailAlreadyExistsError', 'ErrorInfo', 'ExpiredIdTokenError', 'ExpiredSessionCookieError', @@ -46,6 +47,7 @@ 'InvalidIdTokenError', 'InvalidSessionCookieError', 'ListUsersPage', + 'PhoneNumberAlreadyExistsError', 'RevokedIdTokenError', 'RevokedSessionCookieError', 'TokenSignError', @@ -81,6 +83,7 @@ ActionCodeSettings = _user_mgt.ActionCodeSettings CertificateFetchError = _token_gen.CertificateFetchError DELETE_ATTRIBUTE = _user_mgt.DELETE_ATTRIBUTE +EmailAlreadyExistsError = _auth_utils.EmailAlreadyExistsError ErrorInfo = _user_import.ErrorInfo ExpiredIdTokenError = _token_gen.ExpiredIdTokenError ExpiredSessionCookieError = _token_gen.ExpiredSessionCookieError @@ -90,6 +93,7 @@ InvalidIdTokenError = _auth_utils.InvalidIdTokenError InvalidSessionCookieError = _token_gen.InvalidSessionCookieError ListUsersPage = _user_mgt.ListUsersPage +PhoneNumberAlreadyExistsError = _auth_utils.PhoneNumberAlreadyExistsError RevokedIdTokenError = _token_gen.RevokedIdTokenError RevokedSessionCookieError = _token_gen.RevokedSessionCookieError TokenSignError = _token_gen.TokenSignError diff --git a/firebase_admin/exceptions.py b/firebase_admin/exceptions.py index f1297dbb3..bfc3fff1f 100644 --- a/firebase_admin/exceptions.py +++ b/firebase_admin/exceptions.py @@ -38,7 +38,18 @@ class FirebaseError(Exception): - """Base class for all errors raised by the Admin SDK.""" + """Base class for all errors raised by the Admin SDK. + + Args: + code: A string error code that represents the type of the exception. Possible error + codes are defined in https://cloud.google.com/apis/design/errors#handling_errors. + message: A human-readable error message string. + cause: The exception that caused this error (optional). + http_response: If this error was caused by an HTTP error response, this property is + set to the ``requests.Response`` object that represents the HTTP response (optional). + See https://2.python-requests.org/en/master/api/#requests.Response for details of + this object. + """ def __init__(self, code, message, cause=None, http_response=None): Exception.__init__(self, message) diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 34e2b019b..3847ff1ab 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -289,6 +289,12 @@ def test_get_user_by_phone_http_error(self, user_mgt_app): class TestCreateUser(object): + already_exists_errors = { + 'DUPLICATE_EMAIL': auth.EmailAlreadyExistsError, + 'DUPLICATE_LOCAL_ID': auth.UidAlreadyExistsError, + 'PHONE_NUMBER_EXISTS': auth.PhoneNumberAlreadyExistsError, + } + @pytest.mark.parametrize('arg', INVALID_STRINGS[1:] + ['a'*129]) def test_invalid_uid(self, user_mgt_app, arg): with pytest.raises(ValueError): @@ -358,13 +364,15 @@ def test_create_user_error(self, user_mgt_app): assert excinfo.value.http_response is not None assert excinfo.value.cause is not None - def test_uid_already_exists(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 500, '{"error": {"message": "DUPLICATE_LOCAL_ID"}}') - with pytest.raises(auth.UidAlreadyExistsError) as excinfo: + @pytest.mark.parametrize('error_code', already_exists_errors.keys()) + def test_user_already_exists(self, user_mgt_app, error_code): + resp = {'error': {'message': error_code}} + _instrument_user_manager(user_mgt_app, 500, json.dumps(resp)) + exc_type = self.already_exists_errors[error_code] + with pytest.raises(exc_type) as excinfo: auth.create_user(app=user_mgt_app) assert isinstance(excinfo.value, exceptions.AlreadyExistsError) - assert str(excinfo.value) == ('The user with the provided uid already exists ' - '(DUPLICATE_LOCAL_ID).') + assert str(excinfo.value) == '{0} ({1}).'.format(exc_type.default_message, error_code) assert excinfo.value.http_response is not None assert excinfo.value.cause is not None From 7974c05bf09b0de03be09a2ab9c7e99de87b0ee0 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Fri, 9 Aug 2019 16:06:52 -0700 Subject: [PATCH 14/37] Removing the ability to delete user properties by passing None (#320) --- firebase_admin/_user_mgt.py | 26 +++++++++++--------------- tests/test_user_mgt.py | 14 ++------------ 2 files changed, 13 insertions(+), 27 deletions(-) diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 435594224..867b6dd89 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -34,10 +34,6 @@ def __init__(self, description): self.description = description -# Use this internally, until sentinels are available in the public API. -_UNSPECIFIED = Sentinel('No value specified') - - DELETE_ATTRIBUTE = Sentinel('Value used to delete an attribute from a user profile') @@ -524,9 +520,9 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None 'Failed to create new user.', http_response=http_resp) return body.get('localId') - def update_user(self, uid, display_name=_UNSPECIFIED, email=None, phone_number=_UNSPECIFIED, - photo_url=_UNSPECIFIED, password=None, disabled=None, email_verified=None, - valid_since=None, custom_claims=_UNSPECIFIED): + def update_user(self, uid, display_name=None, email=None, phone_number=None, + photo_url=None, password=None, disabled=None, email_verified=None, + valid_since=None, custom_claims=None): """Updates an existing user account with the specified properties""" payload = { 'localId': _auth_utils.validate_uid(uid, required=True), @@ -538,27 +534,27 @@ def update_user(self, uid, display_name=_UNSPECIFIED, email=None, phone_number=_ } remove = [] - if display_name is not _UNSPECIFIED: - if display_name is None or display_name is DELETE_ATTRIBUTE: + if display_name is not None: + if display_name is DELETE_ATTRIBUTE: remove.append('DISPLAY_NAME') else: payload['displayName'] = _auth_utils.validate_display_name(display_name) - if photo_url is not _UNSPECIFIED: - if photo_url is None or photo_url is DELETE_ATTRIBUTE: + if photo_url is not None: + if photo_url is DELETE_ATTRIBUTE: remove.append('PHOTO_URL') else: payload['photoUrl'] = _auth_utils.validate_photo_url(photo_url) if remove: payload['deleteAttribute'] = remove - if phone_number is not _UNSPECIFIED: - if phone_number is None or phone_number is DELETE_ATTRIBUTE: + if phone_number is not None: + if phone_number is DELETE_ATTRIBUTE: payload['deleteProvider'] = ['phone'] else: payload['phoneNumber'] = _auth_utils.validate_phone(phone_number) - if custom_claims is not _UNSPECIFIED: - if custom_claims is None or custom_claims is DELETE_ATTRIBUTE: + if custom_claims is not None: + if custom_claims is DELETE_ATTRIBUTE: custom_claims = {} json_claims = json.dumps(custom_claims) if isinstance( custom_claims, dict) else custom_claims diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 3847ff1ab..dc71b6b6d 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -465,16 +465,6 @@ def test_delete_user_custom_claims(self, user_mgt_app): request = json.loads(recorder[0].body.decode()) assert request == {'localId' : 'testuser', 'customAttributes' : json.dumps({})} - def test_update_user_delete_fields_with_none(self, user_mgt_app): - user_mgt, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') - user_mgt.update_user('testuser', display_name=None, photo_url=None, phone_number=None) - request = json.loads(recorder[0].body.decode()) - assert request == { - 'localId' : 'testuser', - 'deleteAttribute' : ['DISPLAY_NAME', 'PHOTO_URL'], - 'deleteProvider' : ['phone'], - } - def test_update_user_delete_fields(self, user_mgt_app): user_mgt, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') user_mgt.update_user( @@ -561,9 +551,9 @@ def test_set_custom_user_claims_str(self, user_mgt_app): request = json.loads(recorder[0].body.decode()) assert request == {'localId' : 'testuser', 'customAttributes' : claims} - def test_set_custom_user_claims_none(self, user_mgt_app): + def test_set_custom_user_claims_remove(self, user_mgt_app): _, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}') - auth.set_custom_user_claims('testuser', None, app=user_mgt_app) + auth.set_custom_user_claims('testuser', auth.DELETE_ATTRIBUTE, app=user_mgt_app) request = json.loads(recorder[0].body.decode()) assert request == {'localId' : 'testuser', 'customAttributes' : json.dumps({})} From dd3c4bd8dfedfd64425e02815dcf54bfd43c508a Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 14 Aug 2019 17:42:03 -0400 Subject: [PATCH 15/37] Adding beginning of _MLKitService (#323) * Adding beginning of _MLKitService * Added License and Docstring --- firebase_admin/mlkit.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 firebase_admin/mlkit.py diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py new file mode 100644 index 000000000..9271fd668 --- /dev/null +++ b/firebase_admin/mlkit.py @@ -0,0 +1,25 @@ +# Copyright 2019 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase ML Kit module. + +This module contains functions for creating, updating, getting, listing, +deleting, publishing and unpublishing Firebase ML Kit models. +""" + +class _MLKitService(object): + """Firebase MLKit service.""" + + BASE_URL = 'https://mlkit.googleapis.com' + PROJECT_URL = 'https://mlkit.googleapis.com/projects/{0}/' From 65f64c043aac762cf74eff86435b7901e6eb00b4 Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 19 Aug 2019 16:28:17 -0400 Subject: [PATCH 16/37] Firebase ML Kit Get Model API implementation (#326) * added GetModel * Added tests for get_model --- .travis.yml | 1 - firebase_admin/mlkit.py | 73 +++++++++++++++++++++- tests/test_mlkit.py | 134 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 205 insertions(+), 3 deletions(-) create mode 100644 tests/test_mlkit.py diff --git a/.travis.yml b/.travis.yml index 4db3c3708..c89a1db76 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,7 +16,6 @@ before_install: - nvm install 8 && npm install -g firebase-tools script: - pytest - - firebase emulators:exec --only database --project fake-project-id 'pytest integration/test_db.py' cache: pip: true npm: true diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 9271fd668..e86a827e0 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -18,8 +18,77 @@ deleting, publishing and unpublishing Firebase ML Kit models. """ +import re +import requests +import six + +from firebase_admin import _http_client +from firebase_admin import _utils + + +_MLKIT_ATTRIBUTE = '_mlkit' + + +def _get_mlkit_service(app): + """ Returns an _MLKitService instance for an App. + + Args: + app: A Firebase App instance (or None to use the default App). + + Returns: + _MLKitService: An _MLKitService for the specified App instance. + + Raises: + ValueError: If the app argument is invalid. + """ + return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService) + + +def get_model(model_id, app=None): + mlkit_service = _get_mlkit_service(app) + return Model(mlkit_service.get_model(model_id)) + + +class Model(object): + """A Firebase ML Kit Model object.""" + def __init__(self, data): + """Created from a data dictionary.""" + self._data = data + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self._data == other._data # pylint: disable=protected-access + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + #TODO(ifielker): define the Model properties etc + + class _MLKitService(object): """Firebase MLKit service.""" - BASE_URL = 'https://mlkit.googleapis.com' - PROJECT_URL = 'https://mlkit.googleapis.com/projects/{0}/' + PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/' + + def __init__(self, app): + project_id = app.project_id + if not project_id: + raise ValueError( + 'Project ID is required to access MLKit service. Either set the ' + 'projectId option, or use service account credentials.') + self._project_url = _MLKitService.PROJECT_URL.format(project_id) + self._client = _http_client.JsonHttpClient( + credential=app.credential.get_credential(), + base_url=self._project_url) + + def get_model(self, model_id): + if not isinstance(model_id, six.string_types): + raise TypeError('Model ID must be a string.') + if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id): + raise ValueError('Model ID format is invalid.') + try: + return self._client.body('get', url='models/{0}'.format(model_id)) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py new file mode 100644 index 000000000..85edaa4a1 --- /dev/null +++ b/tests/test_mlkit.py @@ -0,0 +1,134 @@ +# Copyright 2019 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin.mlkit module.""" + +import json +import pytest + +import firebase_admin +from firebase_admin import exceptions +from firebase_admin import mlkit +from tests import testutils + +BASE_URL = 'https://mlkit.googleapis.com/v1beta1/' + +PROJECT_ID = 'myProject1' +MODEL_ID_1 = 'modelId1' +MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1) +DISPLAY_NAME_1 = 'displayName1' +MODEL_JSON_1 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1 +} +MODEL_1 = mlkit.Model(MODEL_JSON_1) +_DEFAULT_RESPONSE = json.dumps(MODEL_JSON_1) + +ERROR_CODE = 404 +ERROR_MSG = 'The resource was not found' +ERROR_STATUS = 'NOT_FOUND' +ERROR_JSON = { + 'error': { + 'code': ERROR_CODE, + 'message': ERROR_MSG, + 'status': ERROR_STATUS + } +} +_ERROR_RESPONSE = json.dumps(ERROR_JSON) + + +class TestGetModel(object): + """Tests mlkit.get_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def check_error(err, err_type, msg): + assert isinstance(err, err_type) + assert str(err) == msg + + @staticmethod + def check_firebase_error(err, code, status, msg): + assert isinstance(err, exceptions.FirebaseError) + assert err.code == code + assert err.http_response is not None + assert err.http_response.status_code == status + assert str(err) == msg + + def _get_url(self, project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + def _instrument_mlkit_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE): + if not app: + app = firebase_admin.get_app() + mlkit_service = mlkit._get_mlkit_service(app) + recorder = [] + mlkit_service._client.session.mount( + 'https://mlkit.googleapis.com', + testutils.MockAdapter(payload, status, recorder) + ) + return mlkit_service, recorder + + def test_get_model(self): + _, recorder = self._instrument_mlkit_service() + model = mlkit.get_model(MODEL_ID_1) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == self._get_url(PROJECT_ID, MODEL_ID_1) + assert model == MODEL_1 + assert model._data['name'] == MODEL_NAME_1 + assert model._data['displayName'] == DISPLAY_NAME_1 + + def test_get_model_validation_errors(self): + #Empty model-id + with pytest.raises(ValueError) as err: + mlkit.get_model('') + self.check_error(err.value, ValueError, 'Model ID format is invalid.') + + #None model-id + with pytest.raises(TypeError) as err: + mlkit.get_model(None) + self.check_error(err.value, TypeError, 'Model ID must be a string.') + + #Wrong type + with pytest.raises(TypeError) as err: + mlkit.get_model(12345) + self.check_error(err.value, TypeError, 'Model ID must be a string.') + + #Invalid characters + with pytest.raises(ValueError) as err: + mlkit.get_model('&_*#@:/?') + self.check_error(err.value, ValueError, 'Model ID format is invalid.') + + def test_get_model_error(self): + _, recorder = self._instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE) + with pytest.raises(exceptions.NotFoundError) as err: + mlkit.get_model(MODEL_ID_1) + self.check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE, ERROR_MSG) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == self._get_url(PROJECT_ID, MODEL_ID_1) + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + mlkit.get_model(MODEL_ID_1, app) + testutils.run_without_project_id(evaluate) From a84d3f66bb5829c1b5148cbc6cc44b66bc54e433 Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 19 Aug 2019 18:39:32 -0400 Subject: [PATCH 17/37] Firebase ML Kit Delete Model API implementation (#327) * implement delete model --- firebase_admin/mlkit.py | 24 +++++-- tests/test_mlkit.py | 138 +++++++++++++++++++++++++--------------- 2 files changed, 108 insertions(+), 54 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index e86a827e0..0e42419c6 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -49,6 +49,11 @@ def get_model(model_id, app=None): return Model(mlkit_service.get_model(model_id)) +def delete_model(model_id, app=None): + mlkit_service = _get_mlkit_service(app) + mlkit_service.delete_model(model_id) + + class Model(object): """A Firebase ML Kit Model object.""" def __init__(self, data): @@ -67,6 +72,13 @@ def __ne__(self, other): #TODO(ifielker): define the Model properties etc +def _validate_model_id(model_id): + if not isinstance(model_id, six.string_types): + raise TypeError('Model ID must be a string.') + if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id): + raise ValueError('Model ID format is invalid.') + + class _MLKitService(object): """Firebase MLKit service.""" @@ -84,11 +96,15 @@ def __init__(self, app): base_url=self._project_url) def get_model(self, model_id): - if not isinstance(model_id, six.string_types): - raise TypeError('Model ID must be a string.') - if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id): - raise ValueError('Model ID format is invalid.') + _validate_model_id(model_id) try: return self._client.body('get', url='models/{0}'.format(model_id)) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) + + def delete_model(self, model_id): + _validate_model_id(model_id) + try: + self._client.body('delete', url='models/{0}'.format(model_id)) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 85edaa4a1..b0aadffc6 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -33,7 +33,8 @@ 'displayName': DISPLAY_NAME_1 } MODEL_1 = mlkit.Model(MODEL_JSON_1) -_DEFAULT_RESPONSE = json.dumps(MODEL_JSON_1) +_DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) +_EMPTY_RESPONSE = json.dumps({}) ERROR_CODE = 404 ERROR_MSG = 'The resource was not found' @@ -47,6 +48,37 @@ } _ERROR_RESPONSE = json.dumps(ERROR_JSON) +invalid_model_id_args = [ + ('', ValueError, 'Model ID format is invalid.'), + ('&_*#@:/?', ValueError, 'Model ID format is invalid.'), + (None, TypeError, 'Model ID must be a string.'), + (12345, TypeError, 'Model ID must be a string.'), +] + +def check_error(err, err_type, msg): + assert isinstance(err, err_type) + assert str(err) == msg + + +def check_firebase_error(err, code, status, msg): + assert isinstance(err, exceptions.FirebaseError) + assert err.code == code + assert err.http_response is not None + assert err.http_response.status_code == status + assert str(err) == msg + + +def instrument_mlkit_service(app=None, status=200, payload=None): + if not app: + app = firebase_admin.get_app() + mlkit_service = mlkit._get_mlkit_service(app) + recorder = [] + mlkit_service._client.session.mount( + 'https://mlkit.googleapis.com', + testutils.MockAdapter(payload, status, recorder) + ) + return recorder + class TestGetModel(object): """Tests mlkit.get_model.""" @@ -60,71 +92,33 @@ def teardown_class(cls): testutils.cleanup_apps() @staticmethod - def check_error(err, err_type, msg): - assert isinstance(err, err_type) - assert str(err) == msg - - @staticmethod - def check_firebase_error(err, code, status, msg): - assert isinstance(err, exceptions.FirebaseError) - assert err.code == code - assert err.http_response is not None - assert err.http_response.status_code == status - assert str(err) == msg - - def _get_url(self, project_id, model_id): + def _url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) - def _instrument_mlkit_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE): - if not app: - app = firebase_admin.get_app() - mlkit_service = mlkit._get_mlkit_service(app) - recorder = [] - mlkit_service._client.session.mount( - 'https://mlkit.googleapis.com', - testutils.MockAdapter(payload, status, recorder) - ) - return mlkit_service, recorder - def test_get_model(self): - _, recorder = self._instrument_mlkit_service() + recorder = instrument_mlkit_service(status=200, payload=_DEFAULT_GET_RESPONSE) model = mlkit.get_model(MODEL_ID_1) assert len(recorder) == 1 assert recorder[0].method == 'GET' - assert recorder[0].url == self._get_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) assert model == MODEL_1 assert model._data['name'] == MODEL_NAME_1 assert model._data['displayName'] == DISPLAY_NAME_1 - def test_get_model_validation_errors(self): - #Empty model-id - with pytest.raises(ValueError) as err: - mlkit.get_model('') - self.check_error(err.value, ValueError, 'Model ID format is invalid.') - - #None model-id - with pytest.raises(TypeError) as err: - mlkit.get_model(None) - self.check_error(err.value, TypeError, 'Model ID must be a string.') - - #Wrong type - with pytest.raises(TypeError) as err: - mlkit.get_model(12345) - self.check_error(err.value, TypeError, 'Model ID must be a string.') - - #Invalid characters - with pytest.raises(ValueError) as err: - mlkit.get_model('&_*#@:/?') - self.check_error(err.value, ValueError, 'Model ID format is invalid.') + @pytest.mark.parametrize('model_id, exc_type, error_message', invalid_model_id_args) + def test_get_model_validation_errors(self, model_id, exc_type, error_message): + with pytest.raises(exc_type) as err: + mlkit.get_model(model_id) + check_error(err.value, exc_type, error_message) def test_get_model_error(self): - _, recorder = self._instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE) + recorder = instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE) with pytest.raises(exceptions.NotFoundError) as err: mlkit.get_model(MODEL_ID_1) - self.check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE, ERROR_MSG) + check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE, ERROR_MSG) assert len(recorder) == 1 assert recorder[0].method == 'GET' - assert recorder[0].url == self._get_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1) def test_no_project_id(self): def evaluate(): @@ -132,3 +126,47 @@ def evaluate(): with pytest.raises(ValueError): mlkit.get_model(MODEL_ID_1, app) testutils.run_without_project_id(evaluate) + +class TestDeleteModel(object): + """Tests mlkit.delete_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + def test_delete_model(self): + recorder = instrument_mlkit_service(status=200, payload=_EMPTY_RESPONSE) + mlkit.delete_model(MODEL_ID_1) # no response for delete + assert len(recorder) == 1 + assert recorder[0].method == 'DELETE' + assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) + + @pytest.mark.parametrize('model_id, exc_type, error_message', invalid_model_id_args) + def test_delete_model_validation_errors(self, model_id, exc_type, error_message): + with pytest.raises(exc_type) as err: + mlkit.delete_model(model_id) + check_error(err.value, exc_type, error_message) + + def test_delete_model_error(self): + recorder = instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE) + with pytest.raises(exceptions.NotFoundError) as err: + mlkit.delete_model(MODEL_ID_1) + check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE, ERROR_MSG) + assert len(recorder) == 1 + assert recorder[0].method == 'DELETE' + assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1) + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + mlkit.delete_model(MODEL_ID_1, app) + testutils.run_without_project_id(evaluate) From a247f133a62806d42bf9d2538e100d7b2cfcae04 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 21 Aug 2019 13:57:24 -0400 Subject: [PATCH 18/37] Firebase ML Kit List Models API implementation (#331) * implemented list models plus tests --- firebase_admin/mlkit.py | 148 +++++++++++++++++++++- tests/test_mlkit.py | 266 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 394 insertions(+), 20 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 0e42419c6..dd02ea8db 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -27,6 +27,7 @@ _MLKIT_ATTRIBUTE = '_mlkit' +_MAX_PAGE_SIZE = 100 def _get_mlkit_service(app): @@ -49,6 +50,12 @@ def get_model(model_id, app=None): return Model(mlkit_service.get_model(model_id)) +def list_models(list_filter=None, page_size=None, page_token=None, app=None): + mlkit_service = _get_mlkit_service(app) + return ListModelsPage( + mlkit_service.list_models, list_filter, page_size, page_token) + + def delete_model(model_id, app=None): mlkit_service = _get_mlkit_service(app) mlkit_service.delete_model(model_id) @@ -69,7 +76,107 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) - #TODO(ifielker): define the Model properties etc + @property + def name(self): + return self._data['name'] + + @property + def display_name(self): + return self._data['displayName'] + + #TODO(ifielker): define the rest of the Model properties etc + + +class ListModelsPage(object): + """Represents a page of models in a firebase project. + + Provides methods for traversing the models included in this page, as well as + retrieving subsequent pages of models. The iterator returned by + ``iterate_all()`` can be used to iterate through all the models in the + Firebase project starting from this page. + """ + def __init__(self, list_models_func, list_filter, page_size, page_token): + self._list_models_func = list_models_func + self._list_filter = list_filter + self._page_size = page_size + self._page_token = page_token + self._list_response = list_models_func(list_filter, page_size, page_token) + + @property + def models(self): + """A list of Models from this page.""" + return [Model(model) for model in self._list_response.get('models', [])] + + @property + def list_filter(self): + """The filter string used to filter the models.""" + return self._list_filter + + @property + def next_page_token(self): + return self._list_response.get('nextPageToken', '') + + @property + def has_next_page(self): + """A boolean indicating whether more pages are available.""" + return bool(self.next_page_token) + + def get_next_page(self): + """Retrieves the next page of models if available. + + Returns: + ListModelsPage: Next page of models, or None if this is the last page. + """ + if self.has_next_page: + return ListModelsPage( + self._list_models_func, + self._list_filter, + self._page_size, + self.next_page_token) + return None + + def iterate_all(self): + """Retrieves an iterator for Models. + + Returned iterator will iterate through all the models in the Firebase + project starting from this page. The iterator will never buffer more than + one page of models in memory at a time. + + Returns: + iterator: An iterator of Model instances. + """ + return _ModelIterator(self) + + +class _ModelIterator(object): + """An iterator that allows iterating over models, one at a time. + + This implementation loads a page of models into memory, and iterates on them. + When the whole page has been traversed, it loads another page. This class + never keeps more than one page of entries in memory. + """ + def __init__(self, current_page): + if not isinstance(current_page, ListModelsPage): + raise TypeError('Current page must be a ListModelsPage') + self._current_page = current_page + self._index = 0 + + def next(self): + if self._index == len(self._current_page.models): + if self._current_page.has_next_page: + self._current_page = self._current_page.get_next_page() + self._index = 0 + if self._index < len(self._current_page.models): + result = self._current_page.models[self._index] + self._index += 1 + return result + raise StopIteration + + def __next__(self): + return self.next() + + def __iter__(self): + return self def _validate_model_id(model_id): @@ -79,6 +186,28 @@ def _validate_model_id(model_id): raise ValueError('Model ID format is invalid.') +def _validate_list_filter(list_filter): + if list_filter is not None: + if not isinstance(list_filter, six.string_types): + raise TypeError('List filter must be a string or None.') + + +def _validate_page_size(page_size): + if page_size is not None: + if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck + # Specifically type() to disallow boolean which is a subtype of int + raise TypeError('Page size must be a number or None.') + if page_size < 1 or page_size > _MAX_PAGE_SIZE: + raise ValueError('Page size must be a positive integer between ' + '1 and {0}'.format(_MAX_PAGE_SIZE)) + + +def _validate_page_token(page_token): + if page_token is not None: + if not isinstance(page_token, six.string_types): + raise TypeError('Page token must be a string or None.') + + class _MLKitService(object): """Firebase MLKit service.""" @@ -102,6 +231,23 @@ def get_model(self, model_id): except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) + def list_models(self, list_filter, page_size, page_token): + """ lists Firebase ML Kit models.""" + _validate_list_filter(list_filter) + _validate_page_size(page_size) + _validate_page_token(page_token) + payload = {} + if list_filter: + payload['list_filter'] = list_filter + if page_size: + payload['page_size'] = page_size + if page_token: + payload['page_token'] = page_token + try: + return self._client.body('get', url='models', json=payload) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + def delete_model(self, model_id): _validate_model_id(model_id) try: diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index b0aadffc6..e14bd4371 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -25,6 +25,8 @@ BASE_URL = 'https://mlkit.googleapis.com/v1beta1/' PROJECT_ID = 'myProject1' +PAGE_TOKEN = 'pageToken' +NEXT_PAGE_TOKEN = 'nextPageToken' MODEL_ID_1 = 'modelId1' MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1) DISPLAY_NAME_1 = 'displayName1' @@ -33,20 +35,62 @@ 'displayName': DISPLAY_NAME_1 } MODEL_1 = mlkit.Model(MODEL_JSON_1) -_DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) -_EMPTY_RESPONSE = json.dumps({}) -ERROR_CODE = 404 -ERROR_MSG = 'The resource was not found' -ERROR_STATUS = 'NOT_FOUND' -ERROR_JSON = { +MODEL_ID_2 = 'modelId2' +MODEL_NAME_2 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_2) +DISPLAY_NAME_2 = 'displayName2' +MODEL_JSON_2 = { + 'name': MODEL_NAME_2, + 'displayName': DISPLAY_NAME_2 +} +MODEL_2 = mlkit.Model(MODEL_JSON_2) + +MODEL_ID_3 = 'modelId3' +MODEL_NAME_3 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_3) +DISPLAY_NAME_3 = 'displayName3' +MODEL_JSON_3 = { + 'name': MODEL_NAME_3, + 'displayName': DISPLAY_NAME_3 +} +MODEL_3 = mlkit.Model(MODEL_JSON_3) + +EMPTY_RESPONSE = json.dumps({}) +DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) +NO_MODELS_LIST_RESPONSE = json.dumps({}) +DEFAULT_LIST_RESPONSE = json.dumps({ + 'models': [MODEL_JSON_1, MODEL_JSON_2], + 'nextPageToken': NEXT_PAGE_TOKEN +}) +LAST_PAGE_LIST_RESPONSE = json.dumps({ + 'models': [MODEL_JSON_3] +}) +ONE_PAGE_LIST_RESPONSE = json.dumps({ + 'models': [MODEL_JSON_1, MODEL_JSON_2, MODEL_JSON_3], +}) + +ERROR_CODE_NOT_FOUND = 404 +ERROR_MSG_NOT_FOUND = 'The resource was not found' +ERROR_STATUS_NOT_FOUND = 'NOT_FOUND' +ERROR_JSON_NOT_FOUND = { 'error': { - 'code': ERROR_CODE, - 'message': ERROR_MSG, - 'status': ERROR_STATUS + 'code': ERROR_CODE_NOT_FOUND, + 'message': ERROR_MSG_NOT_FOUND, + 'status': ERROR_STATUS_NOT_FOUND } } -_ERROR_RESPONSE = json.dumps(ERROR_JSON) +ERROR_RESPONSE_NOT_FOUND = json.dumps(ERROR_JSON_NOT_FOUND) + +ERROR_CODE_BAD_REQUEST = 400 +ERROR_MSG_BAD_REQUEST = 'Invalid Argument' +ERROR_STATUS_BAD_REQUEST = 'INVALID_ARGUMENT' +ERROR_JSON_BAD_REQUEST = { + 'error': { + 'code': ERROR_CODE_BAD_REQUEST, + 'message': ERROR_MSG_BAD_REQUEST, + 'status': ERROR_STATUS_BAD_REQUEST + } +} +ERROR_RESPONSE_BAD_REQUEST = json.dumps(ERROR_JSON_BAD_REQUEST) invalid_model_id_args = [ ('', ValueError, 'Model ID format is invalid.'), @@ -54,6 +98,20 @@ (None, TypeError, 'Model ID must be a string.'), (12345, TypeError, 'Model ID must be a string.'), ] +PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ + '1 and {0}'.format(mlkit._MAX_PAGE_SIZE) +invalid_page_size_args = [ + ('abc', TypeError, 'Page size must be a number or None.'), + (4.2, TypeError, 'Page size must be a number or None.'), + (list(), TypeError, 'Page size must be a number or None.'), + (dict(), TypeError, 'Page size must be a number or None.'), + (True, TypeError, 'Page size must be a number or None.'), + (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), + (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), + (mlkit._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) +] +invalid_string_or_none_args = [0, -1, 4.2, 0x10, False, list(), dict()] + def check_error(err, err_type, msg): assert isinstance(err, err_type) @@ -96,14 +154,14 @@ def _url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) def test_get_model(self): - recorder = instrument_mlkit_service(status=200, payload=_DEFAULT_GET_RESPONSE) + recorder = instrument_mlkit_service(status=200, payload=DEFAULT_GET_RESPONSE) model = mlkit.get_model(MODEL_ID_1) assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) assert model == MODEL_1 - assert model._data['name'] == MODEL_NAME_1 - assert model._data['displayName'] == DISPLAY_NAME_1 + assert model.name == MODEL_NAME_1 + assert model.display_name == DISPLAY_NAME_1 @pytest.mark.parametrize('model_id, exc_type, error_message', invalid_model_id_args) def test_get_model_validation_errors(self, model_id, exc_type, error_message): @@ -112,13 +170,18 @@ def test_get_model_validation_errors(self, model_id, exc_type, error_message): check_error(err.value, exc_type, error_message) def test_get_model_error(self): - recorder = instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE) + recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) with pytest.raises(exceptions.NotFoundError) as err: mlkit.get_model(MODEL_ID_1) - check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE, ERROR_MSG) + check_firebase_error( + err.value, + ERROR_STATUS_NOT_FOUND, + ERROR_CODE_NOT_FOUND, + ERROR_MSG_NOT_FOUND + ) assert len(recorder) == 1 assert recorder[0].method == 'GET' - assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) def test_no_project_id(self): def evaluate(): @@ -143,7 +206,7 @@ def _url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) def test_delete_model(self): - recorder = instrument_mlkit_service(status=200, payload=_EMPTY_RESPONSE) + recorder = instrument_mlkit_service(status=200, payload=EMPTY_RESPONSE) mlkit.delete_model(MODEL_ID_1) # no response for delete assert len(recorder) == 1 assert recorder[0].method == 'DELETE' @@ -156,10 +219,15 @@ def test_delete_model_validation_errors(self, model_id, exc_type, error_message) check_error(err.value, exc_type, error_message) def test_delete_model_error(self): - recorder = instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE) + recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) with pytest.raises(exceptions.NotFoundError) as err: mlkit.delete_model(MODEL_ID_1) - check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE, ERROR_MSG) + check_firebase_error( + err.value, + ERROR_STATUS_NOT_FOUND, + ERROR_CODE_NOT_FOUND, + ERROR_MSG_NOT_FOUND + ) assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1) @@ -170,3 +238,163 @@ def evaluate(): with pytest.raises(ValueError): mlkit.delete_model(MODEL_ID_1, app) testutils.run_without_project_id(evaluate) + + +class TestListModels(object): + """Tests mlkit.list_models.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id): + return BASE_URL + 'projects/{0}/models'.format(project_id) + + @staticmethod + def _check_page(page, model_count): + assert isinstance(page, mlkit.ListModelsPage) + assert len(page.models) == model_count + for model in page.models: + assert isinstance(model, mlkit.Model) + + def test_list_models_no_args(self): + recorder = instrument_mlkit_service(status=200, payload=DEFAULT_LIST_RESPONSE) + models_page = mlkit.list_models() + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestListModels._url(PROJECT_ID) + TestListModels._check_page(models_page, 2) + assert models_page.has_next_page + assert models_page.next_page_token == NEXT_PAGE_TOKEN + assert models_page.models[0] == MODEL_1 + assert models_page.models[1] == MODEL_2 + + def test_list_models_with_all_args(self): + recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + models_page = mlkit.list_models( + 'display_name=displayName3', + page_size=10, + page_token=PAGE_TOKEN) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestListModels._url(PROJECT_ID) + assert json.loads(recorder[0].body.decode()) == { + 'list_filter': 'display_name=displayName3', + 'page_size': 10, + 'page_token': PAGE_TOKEN + } + assert isinstance(models_page, mlkit.ListModelsPage) + assert len(models_page.models) == 1 + assert models_page.models[0] == MODEL_3 + assert not models_page.has_next_page + + @pytest.mark.parametrize('list_filter', invalid_string_or_none_args) + def test_list_models_list_filter_validation(self, list_filter): + with pytest.raises(TypeError) as err: + mlkit.list_models(list_filter=list_filter) + check_error(err.value, TypeError, 'List filter must be a string or None.') + + @pytest.mark.parametrize('page_size, exc_type, error_message', invalid_page_size_args) + def test_list_models_page_size_validation(self, page_size, exc_type, error_message): + with pytest.raises(exc_type) as err: + mlkit.list_models(page_size=page_size) + check_error(err.value, exc_type, error_message) + + @pytest.mark.parametrize('page_token', invalid_string_or_none_args) + def test_list_models_page_token_validation(self, page_token): + with pytest.raises(TypeError) as err: + mlkit.list_models(page_token=page_token) + check_error(err.value, TypeError, 'Page token must be a string or None.') + + def test_list_models_error(self): + recorder = instrument_mlkit_service(status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(exceptions.InvalidArgumentError) as err: + mlkit.list_models() + check_firebase_error( + err.value, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestListModels._url(PROJECT_ID) + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + mlkit.list_models(app=app) + testutils.run_without_project_id(evaluate) + + def test_list_single_page(self): + recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + models_page = mlkit.list_models() + assert len(recorder) == 1 + assert models_page.next_page_token == '' + assert models_page.has_next_page is False + assert models_page.get_next_page() is None + models = [model for model in models_page.iterate_all()] + assert len(models) == 1 + + def test_list_multiple_pages(self): + # Page 1 + recorder = instrument_mlkit_service(status=200, payload=DEFAULT_LIST_RESPONSE) + page = mlkit.list_models() + assert len(recorder) == 1 + assert len(page.models) == 2 + assert page.next_page_token == NEXT_PAGE_TOKEN + assert page.has_next_page is True + + # Page 2 + recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + page_2 = page.get_next_page() + assert len(recorder) == 1 + assert len(page_2.models) == 1 + assert page_2.next_page_token == '' + assert page_2.has_next_page is False + assert page_2.get_next_page() is None + + def test_list_models_paged_iteration(self): + # Page 1 + recorder = instrument_mlkit_service(status=200, payload=DEFAULT_LIST_RESPONSE) + page = mlkit.list_models() + assert page.next_page_token == NEXT_PAGE_TOKEN + assert page.has_next_page is True + iterator = page.iterate_all() + for index in range(2): + model = next(iterator) + assert model.display_name == 'displayName{0}'.format(index+1) + assert len(recorder) == 1 + + # Page 2 + recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + model = next(iterator) + assert model.display_name == DISPLAY_NAME_3 + with pytest.raises(StopIteration): + next(iterator) + + def test_list_models_stop_iteration(self): + recorder = instrument_mlkit_service(status=200, payload=ONE_PAGE_LIST_RESPONSE) + page = mlkit.list_models() + assert len(recorder) == 1 + assert len(page.models) == 3 + iterator = page.iterate_all() + models = [model for model in iterator] + assert len(page.models) == 3 + with pytest.raises(StopIteration): + next(iterator) + assert len(models) == 3 + + def test_list_models_no_models(self): + recorder = instrument_mlkit_service(status=200, payload=NO_MODELS_LIST_RESPONSE) + page = mlkit.list_models() + assert len(recorder) == 1 + assert len(page.models) == 0 + models = [model for model in page.iterate_all()] + assert len(models) == 0 From 4618b1e8f4d193ca29f6823236617bbe206ba4ec Mon Sep 17 00:00:00 2001 From: ifielker Date: Thu, 29 Aug 2019 17:03:37 -0400 Subject: [PATCH 19/37] Implementation of Model, ModelFormat, TFLiteModelSource and subclasses (#335) * Implementation of Model, ModelFormat, ModelSource and subclasses --- firebase_admin/mlkit.py | 264 +++++++++++++++++++++++++++++++++++-- tests/test_mlkit.py | 282 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 505 insertions(+), 41 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index dd02ea8db..3f1a825f6 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -18,6 +18,8 @@ deleting, publishing and unpublishing Firebase ML Kit models. """ +import datetime +import numbers import re import requests import six @@ -28,6 +30,12 @@ _MLKIT_ATTRIBUTE = '_mlkit' _MAX_PAGE_SIZE = 100 +_MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+') +_RESOURCE_NAME_PATTERN = re.compile( + r'^projects/(?P[^/]+)/models/(?P[A-Za-z0-9_-]{1,60})$') def _get_mlkit_service(app): @@ -47,7 +55,7 @@ def _get_mlkit_service(app): def get_model(model_id, app=None): mlkit_service = _get_mlkit_service(app) - return Model(mlkit_service.get_model(model_id)) + return Model.from_dict(mlkit_service.get_model(model_id)) def list_models(list_filter=None, page_size=None, page_token=None, app=None): @@ -62,14 +70,39 @@ def delete_model(model_id, app=None): class Model(object): - """A Firebase ML Kit Model object.""" - def __init__(self, data): - """Created from a data dictionary.""" - self._data = data + """A Firebase ML Kit Model object. + + Args: + display_name: The display name of your model - used to identify your model in code. + tags: Optional list of strings associated with your model. Can be used in list queries. + model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details. + """ + def __init__(self, display_name=None, tags=None, model_format=None): + self._data = {} + self._model_format = None + + if display_name is not None: + self.display_name = display_name + if tags is not None: + self.tags = tags + if model_format is not None: + self.model_format = model_format + + @classmethod + def from_dict(cls, data): + data_copy = dict(data) + tflite_format = None + tflite_format_data = data_copy.pop('tfliteModel', None) + if tflite_format_data: + tflite_format = TFLiteFormat.from_dict(tflite_format_data) + model = Model(model_format=tflite_format) + model._data = data_copy # pylint: disable=protected-access + return model def __eq__(self, other): if isinstance(other, self.__class__): - return self._data == other._data # pylint: disable=protected-access + # pylint: disable=protected-access + return self._data == other._data and self._model_format == other._model_format else: return False @@ -77,14 +110,182 @@ def __ne__(self, other): return not self.__eq__(other) @property - def name(self): - return self._data['name'] + def model_id(self): + if not self._data.get('name'): + return None + _, model_id = _validate_and_parse_name(self._data.get('name')) + return model_id @property def display_name(self): - return self._data['displayName'] + return self._data.get('displayName') + + @display_name.setter + def display_name(self, display_name): + self._data['displayName'] = _validate_display_name(display_name) + return self + + @property + def create_time(self): + """Returns the creation timestamp""" + seconds = self._data.get('createTime', {}).get('seconds') + if not isinstance(seconds, numbers.Number): + return None + + return datetime.datetime.fromtimestamp(float(seconds)) + + @property + def update_time(self): + """Returns the last update timestamp""" + seconds = self._data.get('updateTime', {}).get('seconds') + if not isinstance(seconds, numbers.Number): + return None - #TODO(ifielker): define the rest of the Model properties etc + return datetime.datetime.fromtimestamp(float(seconds)) + + @property + def validation_error(self): + return self._data.get('state', {}).get('validationError', {}).get('message') + + @property + def published(self): + return bool(self._data.get('state', {}).get('published')) + + @property + def etag(self): + return self._data.get('etag') + + @property + def model_hash(self): + return self._data.get('modelHash') + + @property + def tags(self): + return self._data.get('tags') + + @tags.setter + def tags(self, tags): + self._data['tags'] = _validate_tags(tags) + return self + + @property + def locked(self): + return bool(self._data.get('activeOperations') and + len(self._data.get('activeOperations')) > 0) + + @property + def model_format(self): + return self._model_format + + @model_format.setter + def model_format(self, model_format): + if model_format is not None: + _validate_model_format(model_format) + self._model_format = model_format #Can be None + return self + + def as_dict(self): + copy = dict(self._data) + if self._model_format: + copy.update(self._model_format.as_dict()) + return copy + + +class ModelFormat(object): + """Abstract base class representing a Model Format such as TFLite.""" + def as_dict(self): + raise NotImplementedError + + +class TFLiteFormat(ModelFormat): + """Model format representing a TFLite model. + + Args: + model_source: A TFLiteModelSource sub class. Specifies the details of the model source. + """ + def __init__(self, model_source=None): + self._data = {} + self._model_source = None + + if model_source is not None: + self.model_source = model_source + + @classmethod + def from_dict(cls, data): + data_copy = dict(data) + model_source = None + gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None) + if gcs_tflite_uri: + model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) + tflite_format = TFLiteFormat(model_source=model_source) + tflite_format._data = data_copy # pylint: disable=protected-access + return tflite_format + + + def __eq__(self, other): + if isinstance(other, self.__class__): + # pylint: disable=protected-access + return self._data == other._data and self._model_source == other._model_source + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def model_source(self): + return self._model_source + + @model_source.setter + def model_source(self, model_source): + if model_source is not None: + if not isinstance(model_source, TFLiteModelSource): + raise TypeError('Model source must be a TFLiteModelSource object.') + self._model_source = model_source # Can be None + + @property + def size_bytes(self): + return self._data.get('sizeBytes') + + def as_dict(self): + copy = dict(self._data) + if self._model_source: + copy.update(self._model_source.as_dict()) + return {'tfliteModel': copy} + + +class TFLiteModelSource(object): + """Abstract base class representing a model source for TFLite format models.""" + def as_dict(self): + raise NotImplementedError + + +class TFLiteGCSModelSource(TFLiteModelSource): + """TFLite model source representing a tflite model file stored in GCS.""" + def __init__(self, gcs_tflite_uri): + self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def gcs_tflite_uri(self): + return self._gcs_tflite_uri + + @gcs_tflite_uri.setter + def gcs_tflite_uri(self, gcs_tflite_uri): + self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) + + def as_dict(self): + return {"gcsTfliteUri": self._gcs_tflite_uri} + + #TODO(ifielker): implement from_saved_model etc. class ListModelsPage(object): @@ -105,7 +306,7 @@ def __init__(self, list_models_func, list_filter, page_size, page_token): @property def models(self): """A list of Models from this page.""" - return [Model(model) for model in self._list_response.get('models', [])] + return [Model.from_dict(model) for model in self._list_response.get('models', [])] @property def list_filter(self): @@ -179,13 +380,48 @@ def __iter__(self): return self +def _validate_and_parse_name(name): + # The resource name is added automatically from API call responses. + # The only way it could be invalid is if someone tries to + # create a model from a dictionary manually and does it incorrectly. + matcher = _RESOURCE_NAME_PATTERN.match(name) + if not matcher: + raise ValueError('Model resource name format is invalid.') + return matcher.group('project_id'), matcher.group('model_id') + + def _validate_model_id(model_id): - if not isinstance(model_id, six.string_types): - raise TypeError('Model ID must be a string.') - if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id): + if not _MODEL_ID_PATTERN.match(model_id): raise ValueError('Model ID format is invalid.') +def _validate_display_name(display_name): + if not _DISPLAY_NAME_PATTERN.match(display_name): + raise ValueError('Display name format is invalid.') + return display_name + + +def _validate_tags(tags): + if not isinstance(tags, list) or not \ + all(isinstance(tag, six.string_types) for tag in tags): + raise TypeError('Tags must be a list of strings.') + if not all(_TAG_PATTERN.match(tag) for tag in tags): + raise ValueError('Tag format is invalid.') + return tags + + +def _validate_gcs_tflite_uri(uri): + # GCS Bucket naming rules are complex. The regex is not comprehensive. + # See https://cloud.google.com/storage/docs/naming for full details. + if not _GCS_TFLITE_URI_PATTERN.match(uri): + raise ValueError('GCS TFLite URI format is invalid.') + return uri + +def _validate_model_format(model_format): + if not isinstance(model_format, ModelFormat): + raise TypeError('Model format must be a ModelFormat object.') + return model_format + def _validate_list_filter(list_filter): if list_filter is not None: if not isinstance(list_filter, six.string_types): diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index e14bd4371..c20982a2b 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -14,6 +14,7 @@ """Test cases for the firebase_admin.mlkit module.""" +import datetime import json import pytest @@ -27,6 +28,25 @@ PROJECT_ID = 'myProject1' PAGE_TOKEN = 'pageToken' NEXT_PAGE_TOKEN = 'nextPageToken' +CREATE_TIME_SECONDS = 1566426374 +CREATE_TIME_JSON = { + 'seconds': CREATE_TIME_SECONDS +} +CREATE_TIME_DATETIME = datetime.datetime.fromtimestamp(CREATE_TIME_SECONDS) + +UPDATE_TIME_SECONDS = 1566426678 +UPDATE_TIME_JSON = { + 'seconds': UPDATE_TIME_SECONDS +} +UPDATE_TIME_DATETIME = datetime.datetime.fromtimestamp(UPDATE_TIME_SECONDS) +ETAG = '33a64df551425fcc55e4d42a148795d9f25f89d4' +MODEL_HASH = '987987a98b98798d098098e09809fc0893897' +TAG_1 = 'Tag1' +TAG_2 = 'Tag2' +TAG_3 = 'Tag3' +TAGS = [TAG_1, TAG_2] +TAGS_2 = [TAG_1, TAG_3] + MODEL_ID_1 = 'modelId1' MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1) DISPLAY_NAME_1 = 'displayName1' @@ -34,7 +54,7 @@ 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1 } -MODEL_1 = mlkit.Model(MODEL_JSON_1) +MODEL_1 = mlkit.Model.from_dict(MODEL_JSON_1) MODEL_ID_2 = 'modelId2' MODEL_NAME_2 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_2) @@ -43,7 +63,7 @@ 'name': MODEL_NAME_2, 'displayName': DISPLAY_NAME_2 } -MODEL_2 = mlkit.Model(MODEL_JSON_2) +MODEL_2 = mlkit.Model.from_dict(MODEL_JSON_2) MODEL_ID_3 = 'modelId3' MODEL_NAME_3 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_3) @@ -52,7 +72,69 @@ 'name': MODEL_NAME_3, 'displayName': DISPLAY_NAME_3 } -MODEL_3 = mlkit.Model(MODEL_JSON_3) +MODEL_3 = mlkit.Model.from_dict(MODEL_JSON_3) + +MODEL_STATE_PUBLISHED_JSON = { + 'published': True +} +VALIDATION_ERROR_CODE = 400 +VALIDATION_ERROR_MSG = 'No model format found for {0}.'.format(MODEL_ID_1) +MODEL_STATE_ERROR_JSON = { + 'validationError': { + 'code': VALIDATION_ERROR_CODE, + 'message': VALIDATION_ERROR_MSG, + } +} + +OPERATION_NOT_DONE_JSON_1 = { + 'name': 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1), + 'metadata': { + '@type': 'type.googleapis.com/google.firebase.ml.v1beta1.ModelOperationMetadata', + 'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1), + 'basic_operation_status': 'BASIC_OPERATION_STATUS_UPLOADING' + } +} + +GCS_TFLITE_URI = 'gs://my_bucket/mymodel.tflite' +GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI} +GCS_TFLITE_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) +TFLITE_FORMAT_JSON = { + 'gcsTfliteUri': GCS_TFLITE_URI, + 'sizeBytes': '1234567' +} +TFLITE_FORMAT = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON) + +GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite' +GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2} +GCS_TFLITE_MODEL_SOURCE_2 = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI_2) +TFLITE_FORMAT_JSON_2 = { + 'gcsTfliteUri': GCS_TFLITE_URI_2, + 'sizeBytes': '2345678' +} +TFLITE_FORMAT_2 = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) + +FULL_MODEL_ERR_STATE_LRO_JSON = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME_JSON, + 'updateTime': UPDATE_TIME_JSON, + 'state': MODEL_STATE_ERROR_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1], +} +FULL_MODEL_PUBLISHED_JSON = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME_JSON, + 'updateTime': UPDATE_TIME_JSON, + 'state': MODEL_STATE_PUBLISHED_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, + 'tfliteModel': TFLITE_FORMAT_JSON +} EMPTY_RESPONSE = json.dumps({}) DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) @@ -93,29 +175,20 @@ ERROR_RESPONSE_BAD_REQUEST = json.dumps(ERROR_JSON_BAD_REQUEST) invalid_model_id_args = [ - ('', ValueError, 'Model ID format is invalid.'), - ('&_*#@:/?', ValueError, 'Model ID format is invalid.'), - (None, TypeError, 'Model ID must be a string.'), - (12345, TypeError, 'Model ID must be a string.'), + ('', ValueError), + ('&_*#@:/?', ValueError), + (None, TypeError), + (12345, TypeError), ] PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ '1 and {0}'.format(mlkit._MAX_PAGE_SIZE) -invalid_page_size_args = [ - ('abc', TypeError, 'Page size must be a number or None.'), - (4.2, TypeError, 'Page size must be a number or None.'), - (list(), TypeError, 'Page size must be a number or None.'), - (dict(), TypeError, 'Page size must be a number or None.'), - (True, TypeError, 'Page size must be a number or None.'), - (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), - (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), - (mlkit._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) -] invalid_string_or_none_args = [0, -1, 4.2, 0x10, False, list(), dict()] -def check_error(err, err_type, msg): +def check_error(err, err_type, msg=None): assert isinstance(err, err_type) - assert str(err) == msg + if msg: + assert str(err) == msg def check_firebase_error(err, code, status, msg): @@ -138,6 +211,151 @@ def instrument_mlkit_service(app=None, status=200, payload=None): return recorder +class TestModel(object): + """Tests mlkit.Model class.""" + + def test_model_success_err_state_lro(self): + model = mlkit.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) + assert model.model_id == MODEL_ID_1 + assert model.display_name == DISPLAY_NAME_1 + assert model.create_time == CREATE_TIME_DATETIME + assert model.update_time == UPDATE_TIME_DATETIME + assert model.validation_error == VALIDATION_ERROR_MSG + assert model.published is False + assert model.etag == ETAG + assert model.model_hash == MODEL_HASH + assert model.tags == TAGS + assert model.locked is True + assert model.model_format is None + assert model.as_dict() == FULL_MODEL_ERR_STATE_LRO_JSON + + def test_model_success_published(self): + model = mlkit.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) + assert model.model_id == MODEL_ID_1 + assert model.display_name == DISPLAY_NAME_1 + assert model.create_time == CREATE_TIME_DATETIME + assert model.update_time == UPDATE_TIME_DATETIME + assert model.validation_error is None + assert model.published is True + assert model.etag == ETAG + assert model.model_hash == MODEL_HASH + assert model.tags == TAGS + assert model.locked is False + assert model.model_format == TFLITE_FORMAT + assert model.as_dict() == FULL_MODEL_PUBLISHED_JSON + + def test_model_keyword_based_creation_and_setters(self): + model = mlkit.Model(display_name=DISPLAY_NAME_1, tags=TAGS, model_format=TFLITE_FORMAT) + assert model.display_name == DISPLAY_NAME_1 + assert model.tags == TAGS + assert model.model_format == TFLITE_FORMAT + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_1, + 'tags': TAGS, + 'tfliteModel': TFLITE_FORMAT_JSON + } + + model.display_name = DISPLAY_NAME_2 + model.tags = TAGS_2 + model.model_format = TFLITE_FORMAT_2 + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_2, + 'tags': TAGS_2, + 'tfliteModel': TFLITE_FORMAT_JSON_2 + } + + def test_model_format_source_creation(self): + model_source = mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) + model_format = mlkit.TFLiteFormat(model_source=model_source) + model = mlkit.Model(display_name=DISPLAY_NAME_1, model_format=model_format) + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_1, + 'tfliteModel': { + 'gcsTfliteUri': GCS_TFLITE_URI + } + } + + def test_model_source_setters(self): + model_source = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) + model_source.gcs_tflite_uri = GCS_TFLITE_URI_2 + assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 + assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2 + + def test_model_format_setters(self): + model_format = mlkit.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) + model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 + assert model_format.model_source == GCS_TFLITE_MODEL_SOURCE_2 + assert model_format.as_dict() == { + 'tfliteModel': { + 'gcsTfliteUri': GCS_TFLITE_URI_2 + } + } + + @pytest.mark.parametrize('display_name, exc_type', [ + ('', ValueError), + ('&_*#@:/?', ValueError), + (12345, TypeError) + ]) + def test_model_display_name_validation_errors(self, display_name, exc_type): + with pytest.raises(exc_type) as err: + mlkit.Model(display_name=display_name) + check_error(err.value, exc_type) + + @pytest.mark.parametrize('tags, exc_type, error_message', [ + ('tag1', TypeError, 'Tags must be a list of strings.'), + (123, TypeError, 'Tags must be a list of strings.'), + (['tag1', 123, 'tag2'], TypeError, 'Tags must be a list of strings.'), + (['tag1', '@#$%^&'], ValueError, 'Tag format is invalid.'), + (['', 'tag2'], ValueError, 'Tag format is invalid.'), + (['sixty-one_characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', + 'tag2'], ValueError, 'Tag format is invalid.') + ]) + def test_model_tags_validation_errors(self, tags, exc_type, error_message): + with pytest.raises(exc_type) as err: + mlkit.Model(tags=tags) + check_error(err.value, exc_type, error_message) + + @pytest.mark.parametrize('model_format', [ + 123, + "abc", + {}, + [], + True + ]) + def test_model_format_validation_errors(self, model_format): + with pytest.raises(TypeError) as err: + mlkit.Model(model_format=model_format) + check_error(err.value, TypeError, 'Model format must be a ModelFormat object.') + + @pytest.mark.parametrize('model_source', [ + 123, + "abc", + {}, + [], + True + ]) + def test_model_source_validation_errors(self, model_source): + with pytest.raises(TypeError) as err: + mlkit.TFLiteFormat(model_source=model_source) + check_error(err.value, TypeError, 'Model source must be a TFLiteModelSource object.') + + @pytest.mark.parametrize('uri, exc_type', [ + (123, TypeError), + ('abc', ValueError), + ('gs://NO_CAPITALS', ValueError), + ('gs://abc/', ValueError), + ('gs://aa/model.tflite', ValueError), + ('gs://@#$%/model.tflite', ValueError), + ('gs://invalid space/model.tflite', ValueError), + ('gs://sixty-four-characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx/model.tflite', + ValueError) + ]) + def test_gcs_tflite_source_validation_errors(self, uri, exc_type): + with pytest.raises(exc_type) as err: + mlkit.TFLiteGCSModelSource(gcs_tflite_uri=uri) + check_error(err.value, exc_type) + + class TestGetModel(object): """Tests mlkit.get_model.""" @classmethod @@ -160,14 +378,14 @@ def test_get_model(self): assert recorder[0].method == 'GET' assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) assert model == MODEL_1 - assert model.name == MODEL_NAME_1 + assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 - @pytest.mark.parametrize('model_id, exc_type, error_message', invalid_model_id_args) - def test_get_model_validation_errors(self, model_id, exc_type, error_message): + @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args) + def test_get_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as err: mlkit.get_model(model_id) - check_error(err.value, exc_type, error_message) + check_error(err.value, exc_type) def test_get_model_error(self): recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) @@ -190,6 +408,7 @@ def evaluate(): mlkit.get_model(MODEL_ID_1, app) testutils.run_without_project_id(evaluate) + class TestDeleteModel(object): """Tests mlkit.delete_model.""" @classmethod @@ -212,11 +431,11 @@ def test_delete_model(self): assert recorder[0].method == 'DELETE' assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) - @pytest.mark.parametrize('model_id, exc_type, error_message', invalid_model_id_args) - def test_delete_model_validation_errors(self, model_id, exc_type, error_message): + @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args) + def test_delete_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as err: mlkit.delete_model(model_id) - check_error(err.value, exc_type, error_message) + check_error(err.value, exc_type) def test_delete_model_error(self): recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) @@ -299,7 +518,16 @@ def test_list_models_list_filter_validation(self, list_filter): mlkit.list_models(list_filter=list_filter) check_error(err.value, TypeError, 'List filter must be a string or None.') - @pytest.mark.parametrize('page_size, exc_type, error_message', invalid_page_size_args) + @pytest.mark.parametrize('page_size, exc_type, error_message', [ + ('abc', TypeError, 'Page size must be a number or None.'), + (4.2, TypeError, 'Page size must be a number or None.'), + (list(), TypeError, 'Page size must be a number or None.'), + (dict(), TypeError, 'Page size must be a number or None.'), + (True, TypeError, 'Page size must be a number or None.'), + (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), + (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), + (mlkit._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) + ]) def test_list_models_page_size_validation(self, page_size, exc_type, error_message): with pytest.raises(exc_type) as err: mlkit.list_models(page_size=page_size) From e5cf14a5b1776f09a9174c3f1e093a2f02ce264c Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 10 Sep 2019 21:03:01 -0400 Subject: [PATCH 20/37] Firebase ML Kit Create Model API implementation (#337) * create model plus long running operation handling * Model.wait_for_unlocked --- firebase_admin/_utils.py | 21 +++ firebase_admin/mlkit.py | 186 +++++++++++++++++++++- tests/test_mlkit.py | 327 +++++++++++++++++++++++++++++++++++---- 3 files changed, 494 insertions(+), 40 deletions(-) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 95ed2c414..fb6e32932 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -106,6 +106,27 @@ def handle_platform_error_from_requests(error, handle_func=None): return exc if exc else _handle_func_requests(error, message, error_dict) +def handle_operation_error(error): + """Constructs a ``FirebaseError`` from the given operation error. + + Args: + error: An error returned by a long running operation. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if not isinstance(error, dict): + return exceptions.UnknownError( + message='Unknown error while making a remote service call: {0}'.format(error), + cause=error) + + status_code = error.get('code') + message = error.get('message') + error_code = _http_status_to_error_code(status_code) + err_type = _error_code_to_exception_type(error_code) + return err_type(message=message) + + def _handle_func_requests(error, message, error_dict): """Constructs a ``FirebaseError`` from the given GCP error. diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 3f1a825f6..8cf8d1f7f 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -21,11 +21,14 @@ import datetime import numbers import re +import time import requests import six + from firebase_admin import _http_client from firebase_admin import _utils +from firebase_admin import exceptions _MLKIT_ATTRIBUTE = '_mlkit' @@ -36,6 +39,9 @@ _GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+') _RESOURCE_NAME_PATTERN = re.compile( r'^projects/(?P[^/]+)/models/(?P[A-Za-z0-9_-]{1,60})$') +_OPERATION_NAME_PATTERN = re.compile( + r'^operations/project/(?P[^/]+)/model/(?P[A-Za-z0-9_-]{1,60})' + + r'/operation/[^/]+$') def _get_mlkit_service(app): @@ -53,18 +59,60 @@ def _get_mlkit_service(app): return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService) +def create_model(model, app=None): + """Creates a model in Firebase ML Kit. + + Args: + model: An mlkit.Model to create. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The model that was created in Firebase ML Kit. + """ + mlkit_service = _get_mlkit_service(app) + return Model.from_dict(mlkit_service.create_model(model), app=app) + + def get_model(model_id, app=None): + """Gets a model from Firebase ML Kit. + + Args: + model_id: The id of the model to get. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The requested model. + """ mlkit_service = _get_mlkit_service(app) - return Model.from_dict(mlkit_service.get_model(model_id)) + return Model.from_dict(mlkit_service.get_model(model_id), app=app) def list_models(list_filter=None, page_size=None, page_token=None, app=None): + """Lists models from Firebase ML Kit. + + Args: + list_filter: a list filter string such as "tags:'tag_1'". None will return all models. + page_size: A number between 1 and 100 inclusive that specifies the maximum + number of models to return per page. None for default. + page_token: A next page token returned from a previous page of results. None + for first page of results. + app: A Firebase app instance (or None to use the default app). + + Returns: + ListModelsPage: A (filtered) list of models. + """ mlkit_service = _get_mlkit_service(app) return ListModelsPage( - mlkit_service.list_models, list_filter, page_size, page_token) + mlkit_service.list_models, list_filter, page_size, page_token, app=app) def delete_model(model_id, app=None): + """Deletes a model from Firebase ML Kit. + + Args: + model_id: The id of the model you wish to delete. + app: A Firebase app instance (or None to use the default app). + """ mlkit_service = _get_mlkit_service(app) mlkit_service.delete_model(model_id) @@ -78,6 +126,7 @@ class Model(object): model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details. """ def __init__(self, display_name=None, tags=None, model_format=None): + self._app = None # Only needed for wait_for_unlo self._data = {} self._model_format = None @@ -89,7 +138,7 @@ def __init__(self, display_name=None, tags=None, model_format=None): self.model_format = model_format @classmethod - def from_dict(cls, data): + def from_dict(cls, data, app=None): data_copy = dict(data) tflite_format = None tflite_format_data = data_copy.pop('tfliteModel', None) @@ -97,8 +146,14 @@ def from_dict(cls, data): tflite_format = TFLiteFormat.from_dict(tflite_format_data) model = Model(model_format=tflite_format) model._data = data_copy # pylint: disable=protected-access + model._app = app # pylint: disable=protected-access return model + def _update_from_dict(self, data): + copy = Model.from_dict(data) + self.model_format = copy.model_format + self._data = copy._data # pylint: disable=protected-access + def __eq__(self, other): if isinstance(other, self.__class__): # pylint: disable=protected-access @@ -173,6 +228,26 @@ def locked(self): return bool(self._data.get('activeOperations') and len(self._data.get('activeOperations')) > 0) + def wait_for_unlocked(self, max_time_seconds=None): + """Waits for the model to be unlocked. (All active operations complete) + + Args: + max_time_seconds: The maximum number of seconds to wait for the model to unlock. + (None for no limit) + + Raises: + exceptions.DeadlineExceeded: If max_time_seconds passed and the model is still locked. + """ + if not self.locked: + return + mlkit_service = _get_mlkit_service(self._app) + op_name = self._data.get('activeOperations')[0].get('name') + model_dict = mlkit_service.handle_operation( + mlkit_service.get_operation(op_name), + wait_for_operation=True, + max_time_seconds=max_time_seconds) + self._update_from_dict(model_dict) + @property def model_format(self): return self._model_format @@ -296,17 +371,20 @@ class ListModelsPage(object): ``iterate_all()`` can be used to iterate through all the models in the Firebase project starting from this page. """ - def __init__(self, list_models_func, list_filter, page_size, page_token): + def __init__(self, list_models_func, list_filter, page_size, page_token, app): self._list_models_func = list_models_func self._list_filter = list_filter self._page_size = page_size self._page_token = page_token + self._app = app self._list_response = list_models_func(list_filter, page_size, page_token) @property def models(self): """A list of Models from this page.""" - return [Model.from_dict(model) for model in self._list_response.get('models', [])] + return [ + Model.from_dict(model, app=self._app) for model in self._list_response.get('models', []) + ] @property def list_filter(self): @@ -333,7 +411,8 @@ def get_next_page(self): self._list_models_func, self._list_filter, self._page_size, - self.next_page_token) + self.next_page_token, + self._app) return None def iterate_all(self): @@ -390,11 +469,25 @@ def _validate_and_parse_name(name): return matcher.group('project_id'), matcher.group('model_id') +def _validate_model(model): + if not isinstance(model, Model): + raise TypeError('Model must be an mlkit.Model.') + if not model.display_name: + raise ValueError('Model must have a display name.') + + def _validate_model_id(model_id): if not _MODEL_ID_PATTERN.match(model_id): raise ValueError('Model ID format is invalid.') +def _validate_and_parse_operation_name(op_name): + matcher = _OPERATION_NAME_PATTERN.match(op_name) + if not matcher: + raise ValueError('Operation name format is invalid.') + return matcher.group('project_id'), matcher.group('model_id') + + def _validate_display_name(display_name): if not _DISPLAY_NAME_PATTERN.match(display_name): raise ValueError('Display name format is invalid.') @@ -417,11 +510,13 @@ def _validate_gcs_tflite_uri(uri): raise ValueError('GCS TFLite URI format is invalid.') return uri + def _validate_model_format(model_format): if not isinstance(model_format, ModelFormat): raise TypeError('Model format must be a ModelFormat object.') return model_format + def _validate_list_filter(list_filter): if list_filter is not None: if not isinstance(list_filter, six.string_types): @@ -448,6 +543,9 @@ class _MLKitService(object): """Firebase MLKit service.""" PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/' + OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/' + POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5 + POLL_BASE_WAIT_TIME_SECONDS = 3 def __init__(self, app): project_id = app.project_id @@ -459,6 +557,82 @@ def __init__(self, app): self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), base_url=self._project_url) + self._operation_client = _http_client.JsonHttpClient( + credential=app.credential.get_credential(), + base_url=_MLKitService.OPERATION_URL) + + def get_operation(self, op_name): + _validate_and_parse_operation_name(op_name) + try: + return self._operation_client.body('get', url=op_name) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + + def _exponential_backoff(self, current_attempt, stop_time): + """Sleeps for the appropriate amount of time. Or throws deadline exceeded.""" + delay_factor = pow(_MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) + wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS + + if stop_time is not None: + max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds() + if max_seconds_left < 1: # allow a bit of time for rpc + raise exceptions.DeadlineExceededError('Polling max time exceeded.') + else: + wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) + time.sleep(wait_time_seconds) + + + def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None): + """Handles long running operations. + + Args: + operation: The operation to handle. + wait_for_operation: Should we allow polling for the operation to complete. + If no polling is requested, a locked model will be returned instead. + max_time_seconds: The maximum seconds to try polling for operation complete. + (None for no limit) + + Returns: + dict: A dictionary of the returned model properties. + + Raises: + TypeError: if the operation is not a dictionary. + ValueError: If the operation is malformed. + err: If the operation exceeds polling attempts or stop_time + """ + if not isinstance(operation, dict): + raise TypeError('Operation must be a dictionary.') + op_name = operation.get('name') + _, model_id = _validate_and_parse_operation_name(op_name) + + current_attempt = 0 + start_time = datetime.datetime.now() + stop_time = (None if max_time_seconds is None else + start_time + datetime.timedelta(seconds=max_time_seconds)) + while wait_for_operation and not operation.get('done'): + # We just got this operation. Wait before getting another + # so we don't exceed the GetOperation maximum request rate. + self._exponential_backoff(current_attempt, stop_time) + operation = self.get_operation(op_name) + current_attempt += 1 + + if operation.get('done'): + if operation.get('response'): + return operation.get('response') + elif operation.get('error'): + raise _utils.handle_operation_error(operation.get('error')) + + # If the operation is not complete or timed out, return a (locked) model instead + return get_model(model_id).as_dict() + + + def create_model(self, model): + _validate_model(model) + try: + return self.handle_operation( + self._client.body('post', url='models', json=model.as_dict())) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) def get_model(self, model_id): _validate_model_id(model_id) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index c20982a2b..78afbdf49 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -29,16 +29,24 @@ PAGE_TOKEN = 'pageToken' NEXT_PAGE_TOKEN = 'nextPageToken' CREATE_TIME_SECONDS = 1566426374 +CREATE_TIME_SECONDS_2 = 1566426385 CREATE_TIME_JSON = { 'seconds': CREATE_TIME_SECONDS } CREATE_TIME_DATETIME = datetime.datetime.fromtimestamp(CREATE_TIME_SECONDS) +CREATE_TIME_JSON_2 = { + 'seconds': CREATE_TIME_SECONDS_2 +} UPDATE_TIME_SECONDS = 1566426678 +UPDATE_TIME_SECONDS_2 = 1566426691 UPDATE_TIME_JSON = { 'seconds': UPDATE_TIME_SECONDS } UPDATE_TIME_DATETIME = datetime.datetime.fromtimestamp(UPDATE_TIME_SECONDS) +UPDATE_TIME_JSON_2 = { + 'seconds': UPDATE_TIME_SECONDS_2 +} ETAG = '33a64df551425fcc55e4d42a148795d9f25f89d4' MODEL_HASH = '987987a98b98798d098098e09809fc0893897' TAG_1 = 'Tag1' @@ -86,8 +94,9 @@ } } +OPERATION_NAME_1 = 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1) OPERATION_NOT_DONE_JSON_1 = { - 'name': 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1), + 'name': OPERATION_NAME_1, 'metadata': { '@type': 'type.googleapis.com/google.firebase.ml.v1beta1.ModelOperationMetadata', 'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1), @@ -113,6 +122,64 @@ } TFLITE_FORMAT_2 = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) +CREATED_MODEL_JSON_1 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME_JSON, + 'updateTime': UPDATE_TIME_JSON, + 'state': MODEL_STATE_ERROR_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, +} +CREATED_MODEL_1 = mlkit.Model.from_dict(CREATED_MODEL_JSON_1) + +LOCKED_MODEL_JSON_1 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME_JSON, + 'updateTime': UPDATE_TIME_JSON, + 'tags': TAGS, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1] +} + +LOCKED_MODEL_JSON_2 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_2, + 'createTime': CREATE_TIME_JSON_2, + 'updateTime': UPDATE_TIME_JSON_2, + 'tags': TAGS_2, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1] +} + +OPERATION_DONE_MODEL_JSON_1 = { + 'name': OPERATION_NAME_1, + 'done': True, + 'response': CREATED_MODEL_JSON_1 +} + +OPERATION_MALFORMED_JSON_1 = { + 'name': OPERATION_NAME_1, + 'done': True, + # if done is true then either response or error should be populated +} + +OPERATION_MISSING_NAME = { + 'done': False +} + +OPERATION_ERROR_CODE = 400 +OPERATION_ERROR_MSG = "Invalid argument" +OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT' +OPERATION_ERROR_JSON_1 = { + 'name': OPERATION_NAME_1, + 'done': True, + 'error': { + 'code': OPERATION_ERROR_CODE, + 'message': OPERATION_ERROR_MSG, + } +} + FULL_MODEL_ERR_STATE_LRO_JSON = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, @@ -135,9 +202,22 @@ 'tags': TAGS, 'tfliteModel': TFLITE_FORMAT_JSON } +FULL_MODEL_PUBLISHED = mlkit.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) +OPERATION_DONE_FULL_MODEL_PUBLISHED_JSON = { + 'name': OPERATION_NAME_1, + 'done': True, + 'response': FULL_MODEL_PUBLISHED_JSON +} EMPTY_RESPONSE = json.dumps({}) +OPERATION_NOT_DONE_RESPONSE = json.dumps(OPERATION_NOT_DONE_JSON_1) +OPERATION_DONE_RESPONSE = json.dumps(OPERATION_DONE_MODEL_JSON_1) +OPERATION_DONE_PUBLISHED_RESPONSE = json.dumps(OPERATION_DONE_FULL_MODEL_PUBLISHED_JSON) +OPERATION_ERROR_RESPONSE = json.dumps(OPERATION_ERROR_JSON_1) +OPERATION_MALFORMED_RESPONSE = json.dumps(OPERATION_MALFORMED_JSON_1) +OPERATION_MISSING_NAME_RESPONSE = json.dumps(OPERATION_MISSING_NAME) DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) +LOCKED_MODEL_2_RESPONSE = json.dumps(LOCKED_MODEL_JSON_2) NO_MODELS_LIST_RESPONSE = json.dumps({}) DEFAULT_LIST_RESPONSE = json.dumps({ 'models': [MODEL_JSON_1, MODEL_JSON_2], @@ -185,13 +265,25 @@ invalid_string_or_none_args = [0, -1, 4.2, 0x10, False, list(), dict()] -def check_error(err, err_type, msg=None): +# For validation type errors +def check_error(excinfo, err_type, msg=None): + err = excinfo.value assert isinstance(err, err_type) if msg: assert str(err) == msg -def check_firebase_error(err, code, status, msg): +# For errors that are returned in an operation +def check_operation_error(excinfo, code, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert err.code == code + assert str(err) == msg + + +# For rpc errors +def check_firebase_error(excinfo, code, status, msg): + err = excinfo.value assert isinstance(err, exceptions.FirebaseError) assert err.code == code assert err.http_response is not None @@ -199,20 +291,43 @@ def check_firebase_error(err, code, status, msg): assert str(err) == msg -def instrument_mlkit_service(app=None, status=200, payload=None): +def instrument_mlkit_service(status=200, payload=None, operations=False, app=None): if not app: app = firebase_admin.get_app() mlkit_service = mlkit._get_mlkit_service(app) recorder = [] - mlkit_service._client.session.mount( - 'https://mlkit.googleapis.com', - testutils.MockAdapter(payload, status, recorder) - ) + session_url = 'https://mlkit.googleapis.com/v1beta1/' + + if isinstance(status, list): + adapter = testutils.MockMultiRequestAdapter + else: + adapter = testutils.MockAdapter + + if operations: + mlkit_service._operation_client.session.mount( + session_url, adapter(payload, status, recorder)) + else: + mlkit_service._client.session.mount( + session_url, adapter(payload, status, recorder)) return recorder class TestModel(object): """Tests mlkit.Model class.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _op_url(project_id, model_id): + return BASE_URL + \ + 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) def test_model_success_err_state_lro(self): model = mlkit.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) @@ -297,9 +412,9 @@ def test_model_format_setters(self): (12345, TypeError) ]) def test_model_display_name_validation_errors(self, display_name, exc_type): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.Model(display_name=display_name) - check_error(err.value, exc_type) + check_error(excinfo, exc_type) @pytest.mark.parametrize('tags, exc_type, error_message', [ ('tag1', TypeError, 'Tags must be a list of strings.'), @@ -311,9 +426,9 @@ def test_model_display_name_validation_errors(self, display_name, exc_type): 'tag2'], ValueError, 'Tag format is invalid.') ]) def test_model_tags_validation_errors(self, tags, exc_type, error_message): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.Model(tags=tags) - check_error(err.value, exc_type, error_message) + check_error(excinfo, exc_type, error_message) @pytest.mark.parametrize('model_format', [ 123, @@ -323,9 +438,9 @@ def test_model_tags_validation_errors(self, tags, exc_type, error_message): True ]) def test_model_format_validation_errors(self, model_format): - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError) as excinfo: mlkit.Model(model_format=model_format) - check_error(err.value, TypeError, 'Model format must be a ModelFormat object.') + check_error(excinfo, TypeError, 'Model format must be a ModelFormat object.') @pytest.mark.parametrize('model_source', [ 123, @@ -335,9 +450,9 @@ def test_model_format_validation_errors(self, model_format): True ]) def test_model_source_validation_errors(self, model_source): - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError) as excinfo: mlkit.TFLiteFormat(model_source=model_source) - check_error(err.value, TypeError, 'Model source must be a TFLiteModelSource object.') + check_error(excinfo, TypeError, 'Model source must be a TFLiteModelSource object.') @pytest.mark.parametrize('uri, exc_type', [ (123, TypeError), @@ -351,9 +466,153 @@ def test_model_source_validation_errors(self, model_source): ValueError) ]) def test_gcs_tflite_source_validation_errors(self, uri, exc_type): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.TFLiteGCSModelSource(gcs_tflite_uri=uri) - check_error(err.value, exc_type) + check_error(excinfo, exc_type) + + def test_wait_for_unlocked_not_locked(self): + model = mlkit.Model(display_name="not_locked") + model.wait_for_unlocked() + + def test_wait_for_unlocked(self): + recorder = instrument_mlkit_service(status=200, + operations=True, + payload=OPERATION_DONE_PUBLISHED_RESPONSE) + model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_1) + model.wait_for_unlocked() + assert model == FULL_MODEL_PUBLISHED + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestModel._op_url(PROJECT_ID, MODEL_ID_1) + + def test_wait_for_unlocked_timeout(self): + recorder = instrument_mlkit_service( + status=200, operations=True, payload=OPERATION_NOT_DONE_RESPONSE) + mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 3 # longer so timeout applies immediately + model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_1) + with pytest.raises(Exception) as excinfo: + model.wait_for_unlocked(max_time_seconds=0.1) + check_error(excinfo, exceptions.DeadlineExceededError, 'Polling max time exceeded.') + assert len(recorder) == 1 + + +class TestCreateModel(object): + """Tests mlkit.create_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id): + return BASE_URL + 'projects/{0}/models'.format(project_id) + + @staticmethod + def _op_url(project_id, model_id): + return BASE_URL + \ + 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + + @staticmethod + def _get_url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + def test_immediate_done(self): + instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = mlkit.create_model(MODEL_1) + assert model == CREATED_MODEL_1 + + def test_returns_locked(self): + recorder = instrument_mlkit_service( + status=[200, 200], + payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) + model = mlkit.create_model(MODEL_1) + + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'POST' + assert recorder[0].url == TestCreateModel._url(PROJECT_ID) + assert recorder[1].method == 'GET' + assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) + + def test_operation_error(self): + instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) + with pytest.raises(Exception) as excinfo: + mlkit.create_model(MODEL_1) + # The http request succeeded, the operation returned contains a create failure + check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) + + def test_malformed_operation(self): + recorder = instrument_mlkit_service( + status=[200, 200], + payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) + model = mlkit.create_model(MODEL_1) + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'POST' + assert recorder[0].url == TestCreateModel._url(PROJECT_ID) + assert recorder[1].method == 'GET' + assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) + + def test_rpc_error_create(self): + create_recorder = instrument_mlkit_service( + status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(Exception) as excinfo: + mlkit.create_model(MODEL_1) + check_firebase_error( + excinfo, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(create_recorder) == 1 + + @pytest.mark.parametrize('model', [ + 'abc', + 4.2, + list(), + dict(), + True, + -1, + 0, + None + ]) + def test_not_model(self, model): + with pytest.raises(Exception) as excinfo: + mlkit.create_model(model) + check_error(excinfo, TypeError, 'Model must be an mlkit.Model.') + + def test_missing_display_name(self): + with pytest.raises(Exception) as excinfo: + mlkit.create_model(mlkit.Model.from_dict({})) + check_error(excinfo, ValueError, 'Model must have a display name.') + + def test_missing_op_name(self): + instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) + with pytest.raises(Exception) as excinfo: + mlkit.create_model(MODEL_1) + check_error(excinfo, TypeError) + + @pytest.mark.parametrize('op_name', [ + 'abc', + '123', + 'projects/operations/project/1234/model/abc/operation/123', + 'operations/project/model/abc/operation/123', + 'operations/project/123/model/$#@/operation/123', + 'operations/project/1234/model/abc/operation/123/extrathing', + ]) + def test_invalid_op_name(self, op_name): + payload = json.dumps({'name': op_name}) + instrument_mlkit_service(status=200, payload=payload) + with pytest.raises(Exception) as excinfo: + mlkit.create_model(MODEL_1) + check_error(excinfo, ValueError, 'Operation name format is invalid.') class TestGetModel(object): @@ -383,16 +642,16 @@ def test_get_model(self): @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args) def test_get_model_validation_errors(self, model_id, exc_type): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.get_model(model_id) - check_error(err.value, exc_type) + check_error(excinfo, exc_type) def test_get_model_error(self): recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) - with pytest.raises(exceptions.NotFoundError) as err: + with pytest.raises(exceptions.NotFoundError) as excinfo: mlkit.get_model(MODEL_ID_1) check_firebase_error( - err.value, + excinfo, ERROR_STATUS_NOT_FOUND, ERROR_CODE_NOT_FOUND, ERROR_MSG_NOT_FOUND @@ -433,16 +692,16 @@ def test_delete_model(self): @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args) def test_delete_model_validation_errors(self, model_id, exc_type): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.delete_model(model_id) - check_error(err.value, exc_type) + check_error(excinfo, exc_type) def test_delete_model_error(self): recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) - with pytest.raises(exceptions.NotFoundError) as err: + with pytest.raises(exceptions.NotFoundError) as excinfo: mlkit.delete_model(MODEL_ID_1) check_firebase_error( - err.value, + excinfo, ERROR_STATUS_NOT_FOUND, ERROR_CODE_NOT_FOUND, ERROR_MSG_NOT_FOUND @@ -514,9 +773,9 @@ def test_list_models_with_all_args(self): @pytest.mark.parametrize('list_filter', invalid_string_or_none_args) def test_list_models_list_filter_validation(self, list_filter): - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError) as excinfo: mlkit.list_models(list_filter=list_filter) - check_error(err.value, TypeError, 'List filter must be a string or None.') + check_error(excinfo, TypeError, 'List filter must be a string or None.') @pytest.mark.parametrize('page_size, exc_type, error_message', [ ('abc', TypeError, 'Page size must be a number or None.'), @@ -529,22 +788,22 @@ def test_list_models_list_filter_validation(self, list_filter): (mlkit._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) ]) def test_list_models_page_size_validation(self, page_size, exc_type, error_message): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.list_models(page_size=page_size) - check_error(err.value, exc_type, error_message) + check_error(excinfo, exc_type, error_message) @pytest.mark.parametrize('page_token', invalid_string_or_none_args) def test_list_models_page_token_validation(self, page_token): - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError) as excinfo: mlkit.list_models(page_token=page_token) - check_error(err.value, TypeError, 'Page token must be a string or None.') + check_error(excinfo, TypeError, 'Page token must be a string or None.') def test_list_models_error(self): recorder = instrument_mlkit_service(status=400, payload=ERROR_RESPONSE_BAD_REQUEST) - with pytest.raises(exceptions.InvalidArgumentError) as err: + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: mlkit.list_models() check_firebase_error( - err.value, + excinfo, ERROR_STATUS_BAD_REQUEST, ERROR_CODE_BAD_REQUEST, ERROR_MSG_BAD_REQUEST From 2a3be7762065b04daec2c6b9baf0e0fdb62cec87 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 11 Sep 2019 13:46:21 -0400 Subject: [PATCH 21/37] Firebase ML Kit Update Model API implementation (#343) * Firebase ML Kit Create Model API implementation --- firebase_admin/mlkit.py | 29 +++++++- tests/test_mlkit.py | 161 ++++++++++++++++++++++++++++++++-------- 2 files changed, 156 insertions(+), 34 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 8cf8d1f7f..b9b56c8f4 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -73,6 +73,20 @@ def create_model(model, app=None): return Model.from_dict(mlkit_service.create_model(model), app=app) +def update_model(model, app=None): + """Updates a model in Firebase ML Kit. + + Args: + model: The mlkit.Model to update. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The updated model. + """ + mlkit_service = _get_mlkit_service(app) + return Model.from_dict(mlkit_service.update_model(model), app=app) + + def get_model(model_id, app=None): """Gets a model from Firebase ML Kit. @@ -469,10 +483,10 @@ def _validate_and_parse_name(name): return matcher.group('project_id'), matcher.group('model_id') -def _validate_model(model): +def _validate_model(model, update_mask=None): if not isinstance(model, Model): raise TypeError('Model must be an mlkit.Model.') - if not model.display_name: + if update_mask is None and not model.display_name: raise ValueError('Model must have a display name.') @@ -634,6 +648,17 @@ def create_model(self, model): except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) + def update_model(self, model, update_mask=None): + _validate_model(model, update_mask) + data = {'model': model.as_dict()} + if update_mask is not None: + data['updateMask'] = update_mask + try: + return self.handle_operation( + self._client.body('patch', url='models/{0}'.format(model.model_id), json=data)) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + def get_model(self, model_id): _validate_model_id(model_id) try: diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 78afbdf49..e93bbd7e9 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -24,7 +24,6 @@ from tests import testutils BASE_URL = 'https://mlkit.googleapis.com/v1beta1/' - PROJECT_ID = 'myProject1' PAGE_TOKEN = 'pageToken' NEXT_PAGE_TOKEN = 'nextPageToken' @@ -122,7 +121,7 @@ } TFLITE_FORMAT_2 = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) -CREATED_MODEL_JSON_1 = { +CREATED_UPDATED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, 'createTime': CREATE_TIME_JSON, @@ -132,7 +131,7 @@ 'modelHash': MODEL_HASH, 'tags': TAGS, } -CREATED_MODEL_1 = mlkit.Model.from_dict(CREATED_MODEL_JSON_1) +CREATED_UPDATED_MODEL_1 = mlkit.Model.from_dict(CREATED_UPDATED_MODEL_JSON_1) LOCKED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, @@ -155,19 +154,16 @@ OPERATION_DONE_MODEL_JSON_1 = { 'name': OPERATION_NAME_1, 'done': True, - 'response': CREATED_MODEL_JSON_1 + 'response': CREATED_UPDATED_MODEL_JSON_1 } - OPERATION_MALFORMED_JSON_1 = { 'name': OPERATION_NAME_1, 'done': True, # if done is true then either response or error should be populated } - OPERATION_MISSING_NAME = { 'done': False } - OPERATION_ERROR_CODE = 400 OPERATION_ERROR_MSG = "Invalid argument" OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT' @@ -254,15 +250,33 @@ } ERROR_RESPONSE_BAD_REQUEST = json.dumps(ERROR_JSON_BAD_REQUEST) -invalid_model_id_args = [ +INVALID_MODEL_ID_ARGS = [ ('', ValueError), ('&_*#@:/?', ValueError), (None, TypeError), (12345, TypeError), ] +INVALID_MODEL_ARGS = [ + 'abc', + 4.2, + list(), + dict(), + True, + -1, + 0, + None +] +INVALID_OP_NAME_ARGS = [ + 'abc', + '123', + 'projects/operations/project/1234/model/abc/operation/123', + 'operations/project/model/abc/operation/123', + 'operations/project/123/model/$#@/operation/123', + 'operations/project/1234/model/abc/operation/123/extrathing', +] PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ '1 and {0}'.format(mlkit._MAX_PAGE_SIZE) -invalid_string_or_none_args = [0, -1, 4.2, 0x10, False, list(), dict()] +INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, list(), dict()] # For validation type errors @@ -524,7 +538,7 @@ def _get_url(project_id, model_id): def test_immediate_done(self): instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) model = mlkit.create_model(MODEL_1) - assert model == CREATED_MODEL_1 + assert model == CREATED_UPDATED_MODEL_1 def test_returns_locked(self): recorder = instrument_mlkit_service( @@ -573,16 +587,7 @@ def test_rpc_error_create(self): ) assert len(create_recorder) == 1 - @pytest.mark.parametrize('model', [ - 'abc', - 4.2, - list(), - dict(), - True, - -1, - 0, - None - ]) + @pytest.mark.parametrize('model', INVALID_MODEL_ARGS) def test_not_model(self, model): with pytest.raises(Exception) as excinfo: mlkit.create_model(model) @@ -599,14 +604,7 @@ def test_missing_op_name(self): mlkit.create_model(MODEL_1) check_error(excinfo, TypeError) - @pytest.mark.parametrize('op_name', [ - 'abc', - '123', - 'projects/operations/project/1234/model/abc/operation/123', - 'operations/project/model/abc/operation/123', - 'operations/project/123/model/$#@/operation/123', - 'operations/project/1234/model/abc/operation/123/extrathing', - ]) + @pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS) def test_invalid_op_name(self, op_name): payload = json.dumps({'name': op_name}) instrument_mlkit_service(status=200, payload=payload) @@ -615,6 +613,105 @@ def test_invalid_op_name(self, op_name): check_error(excinfo, ValueError, 'Operation name format is invalid.') +class TestUpdateModel(object): + """Tests mlkit.update_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + @staticmethod + def _op_url(project_id, model_id): + return BASE_URL + \ + 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + + def test_immediate_done(self): + instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = mlkit.update_model(MODEL_1) + assert model == CREATED_UPDATED_MODEL_1 + + def test_returns_locked(self): + recorder = instrument_mlkit_service( + status=[200, 200], + payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) + model = mlkit.update_model(MODEL_1) + + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].method == 'GET' + assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + + def test_operation_error(self): + instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) + with pytest.raises(Exception) as excinfo: + mlkit.update_model(MODEL_1) + # The http request succeeded, the operation returned contains a create failure + check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) + + def test_malformed_operation(self): + recorder = instrument_mlkit_service( + status=[200, 200], + payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) + model = mlkit.update_model(MODEL_1) + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].method == 'GET' + assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + + def test_rpc_error_create(self): + create_recorder = instrument_mlkit_service( + status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(Exception) as excinfo: + mlkit.update_model(MODEL_1) + check_firebase_error( + excinfo, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(create_recorder) == 1 + + @pytest.mark.parametrize('model', INVALID_MODEL_ARGS) + def test_not_model(self, model): + with pytest.raises(Exception) as excinfo: + mlkit.update_model(model) + check_error(excinfo, TypeError, 'Model must be an mlkit.Model.') + + def test_missing_display_name(self): + with pytest.raises(Exception) as excinfo: + mlkit.update_model(mlkit.Model.from_dict({})) + check_error(excinfo, ValueError, 'Model must have a display name.') + + def test_missing_op_name(self): + instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) + with pytest.raises(Exception) as excinfo: + mlkit.update_model(MODEL_1) + check_error(excinfo, TypeError) + + @pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS) + def test_invalid_op_name(self, op_name): + payload = json.dumps({'name': op_name}) + instrument_mlkit_service(status=200, payload=payload) + with pytest.raises(Exception) as excinfo: + mlkit.update_model(MODEL_1) + check_error(excinfo, ValueError, 'Operation name format is invalid.') + + class TestGetModel(object): """Tests mlkit.get_model.""" @classmethod @@ -640,7 +737,7 @@ def test_get_model(self): assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 - @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args) + @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) def test_get_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as excinfo: mlkit.get_model(model_id) @@ -690,7 +787,7 @@ def test_delete_model(self): assert recorder[0].method == 'DELETE' assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) - @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args) + @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) def test_delete_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as excinfo: mlkit.delete_model(model_id) @@ -771,7 +868,7 @@ def test_list_models_with_all_args(self): assert models_page.models[0] == MODEL_3 assert not models_page.has_next_page - @pytest.mark.parametrize('list_filter', invalid_string_or_none_args) + @pytest.mark.parametrize('list_filter', INVALID_STRING_OR_NONE_ARGS) def test_list_models_list_filter_validation(self, list_filter): with pytest.raises(TypeError) as excinfo: mlkit.list_models(list_filter=list_filter) @@ -792,7 +889,7 @@ def test_list_models_page_size_validation(self, page_size, exc_type, error_messa mlkit.list_models(page_size=page_size) check_error(excinfo, exc_type, error_message) - @pytest.mark.parametrize('page_token', invalid_string_or_none_args) + @pytest.mark.parametrize('page_token', INVALID_STRING_OR_NONE_ARGS) def test_list_models_page_token_validation(self, page_token): with pytest.raises(TypeError) as excinfo: mlkit.list_models(page_token=page_token) From 03441723bbb589782c40897bb8f01664bc0cc205 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 11 Sep 2019 19:55:03 -0400 Subject: [PATCH 22/37] Firebase ML Kit Publish and Unpublish Implementation (#345) * Firebase ML Kit Publish and Unpublish Implementation --- firebase_admin/mlkit.py | 46 ++++++++++++++++++-- tests/test_mlkit.py | 95 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 135 insertions(+), 6 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index b9b56c8f4..91cedbedc 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -87,6 +87,34 @@ def update_model(model, app=None): return Model.from_dict(mlkit_service.update_model(model), app=app) +def publish_model(model_id, app=None): + """Publishes a model in Firebase ML Kit. + + Args: + model_id: The id of the model to publish. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The published model. + """ + mlkit_service = _get_mlkit_service(app) + return Model.from_dict(mlkit_service.set_published(model_id, publish=True), app=app) + + +def unpublish_model(model_id, app=None): + """Unpublishes a model in Firebase ML Kit. + + Args: + model_id: The id of the model to unpublish. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The unpublished model. + """ + mlkit_service = _get_mlkit_service(app) + return Model.from_dict(mlkit_service.set_published(model_id, publish=False), app=app) + + def get_model(model_id, app=None): """Gets a model from Firebase ML Kit. @@ -562,12 +590,12 @@ class _MLKitService(object): POLL_BASE_WAIT_TIME_SECONDS = 3 def __init__(self, app): - project_id = app.project_id - if not project_id: + self._project_id = app.project_id + if not self._project_id: raise ValueError( 'Project ID is required to access MLKit service. Either set the ' 'projectId option, or use service account credentials.') - self._project_url = _MLKitService.PROJECT_URL.format(project_id) + self._project_url = _MLKitService.PROJECT_URL.format(self._project_id) self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), base_url=self._project_url) @@ -595,7 +623,6 @@ def _exponential_backoff(self, current_attempt, stop_time): wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) time.sleep(wait_time_seconds) - def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None): """Handles long running operations. @@ -659,6 +686,17 @@ def update_model(self, model, update_mask=None): except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) + def set_published(self, model_id, publish): + _validate_model_id(model_id) + model_name = 'projects/{0}/models/{1}'.format(self._project_id, model_id) + model = Model.from_dict({ + 'name': model_name, + 'state': { + 'published': publish + } + }) + return self.update_model(model, update_mask='state.published') + def get_model(self, model_id): _validate_model_id(model_id) try: diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index e93bbd7e9..50fed4e1b 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -657,7 +657,7 @@ def test_operation_error(self): instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) with pytest.raises(Exception) as excinfo: mlkit.update_model(MODEL_1) - # The http request succeeded, the operation returned contains a create failure + # The http request succeeded, the operation returned contains an update failure check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): @@ -673,7 +673,7 @@ def test_malformed_operation(self): assert recorder[1].method == 'GET' assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) - def test_rpc_error_create(self): + def test_rpc_error(self): create_recorder = instrument_mlkit_service( status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(Exception) as excinfo: @@ -712,6 +712,97 @@ def test_invalid_op_name(self, op_name): check_error(excinfo, ValueError, 'Operation name format is invalid.') +class TestPublishUnpublish(object): + """Tests mlkit.publish_model and mlkit.unpublish_model.""" + + PUBLISH_UNPUBLISH_WITH_ARGS = [ + (mlkit.publish_model, True), + (mlkit.unpublish_model, False) + ] + PUBLISH_UNPUBLISH_FUNCS = [item[0] for item in PUBLISH_UNPUBLISH_WITH_ARGS] + + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + @staticmethod + def _op_url(project_id, model_id): + return BASE_URL + \ + 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + + @pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS) + def test_immediate_done(self, publish_function, published): + recorder = instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = publish_function(MODEL_ID_1) + assert model == CREATED_UPDATED_MODEL_1 + assert len(recorder) == 1 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + body = json.loads(recorder[0].body.decode()) + assert body.get('model', {}).get('state', {}).get('published', None) is published + assert body.get('updateMask', {}) == 'state.published' + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_returns_locked(self, publish_function): + recorder = instrument_mlkit_service( + status=[200, 200], + payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) + model = publish_function(MODEL_ID_1) + + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].method == 'GET' + assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_operation_error(self, publish_function): + instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) + with pytest.raises(Exception) as excinfo: + publish_function(MODEL_ID_1) + # The http request succeeded, the operation returned contains an update failure + check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_malformed_operation(self, publish_function): + recorder = instrument_mlkit_service( + status=[200, 200], + payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) + model = publish_function(MODEL_ID_1) + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].method == 'GET' + assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_rpc_error(self, publish_function): + create_recorder = instrument_mlkit_service( + status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(Exception) as excinfo: + publish_function(MODEL_ID_1) + check_firebase_error( + excinfo, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(create_recorder) == 1 + class TestGetModel(object): """Tests mlkit.get_model.""" @classmethod From cd5e82a88e8bcebf6ccf4d6aef40d05761f5f1e8 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 17 Sep 2019 17:35:04 -0400 Subject: [PATCH 23/37] Firebase ML Kit TFLiteGCSModelSource.from_tflite_model implementation and conversion helpers (#346) * Firebase ML Kit TFLiteGCSModelSource.from_tflite_model implementation * support for tensorflow lite conversion helpers (version 1.x) --- firebase_admin/mlkit.py | 169 ++++++++++++++++++++++++++++++++++++---- tests/test_mlkit.py | 50 +++++++++++- 2 files changed, 205 insertions(+), 14 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 91cedbedc..8e78a26ce 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -18,6 +18,7 @@ deleting, publishing and unpublishing Firebase ML Kit models. """ + import datetime import numbers import re @@ -30,13 +31,27 @@ from firebase_admin import _utils from firebase_admin import exceptions +# pylint: disable=import-error,no-name-in-module +try: + from firebase_admin import storage + _GCS_ENABLED = True +except ImportError: + _GCS_ENABLED = False + +# pylint: disable=import-error,no-name-in-module +try: + import tensorflow as tf + _TF_ENABLED = True +except ImportError: + _TF_ENABLED = False _MLKIT_ATTRIBUTE = '_mlkit' _MAX_PAGE_SIZE = 100 _MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') _DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') _TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') -_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+') +_GCS_TFLITE_URI_PATTERN = re.compile( + r'^gs://(?P[a-z0-9_.-]{3,63})/(?P.+)$') _RESOURCE_NAME_PATTERN = re.compile( r'^projects/(?P[^/]+)/models/(?P[A-Za-z0-9_-]{1,60})$') _OPERATION_NAME_PATTERN = re.compile( @@ -301,16 +316,16 @@ def model_format(self, model_format): self._model_format = model_format #Can be None return self - def as_dict(self): + def as_dict(self, for_upload=False): copy = dict(self._data) if self._model_format: - copy.update(self._model_format.as_dict()) + copy.update(self._model_format.as_dict(for_upload=for_upload)) return copy class ModelFormat(object): """Abstract base class representing a Model Format such as TFLite.""" - def as_dict(self): + def as_dict(self, for_upload=False): raise NotImplementedError @@ -364,22 +379,70 @@ def model_source(self, model_source): def size_bytes(self): return self._data.get('sizeBytes') - def as_dict(self): + def as_dict(self, for_upload=False): copy = dict(self._data) if self._model_source: - copy.update(self._model_source.as_dict()) + copy.update(self._model_source.as_dict(for_upload=for_upload)) return {'tfliteModel': copy} class TFLiteModelSource(object): """Abstract base class representing a model source for TFLite format models.""" - def as_dict(self): + def as_dict(self, for_upload=False): raise NotImplementedError +class _CloudStorageClient(object): + """Cloud Storage helper class""" + + GCS_URI = 'gs://{0}/{1}' + BLOB_NAME = 'Firebase/MLKit/Models/{0}' + + @staticmethod + def _assert_gcs_enabled(): + if not _GCS_ENABLED: + raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' + 'to install the "google-cloud-storage" module.') + + @staticmethod + def _parse_gcs_tflite_uri(uri): + # GCS Bucket naming rules are complex. The regex is not comprehensive. + # See https://cloud.google.com/storage/docs/naming for full details. + matcher = _GCS_TFLITE_URI_PATTERN.match(uri) + if not matcher: + raise ValueError('GCS TFLite URI format is invalid.') + return matcher.group('bucket_name'), matcher.group('blob_name') + + @staticmethod + def upload(bucket_name, model_file_name, app): + _CloudStorageClient._assert_gcs_enabled() + bucket = storage.bucket(bucket_name, app=app) + blob_name = _CloudStorageClient.BLOB_NAME.format(model_file_name) + blob = bucket.blob(blob_name) + blob.upload_from_filename(model_file_name) + return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name) + + @staticmethod + def sign_uri(gcs_tflite_uri, app): + """Makes the gcs_tflite_uri readable for GET for 10 minutes via signed_uri.""" + _CloudStorageClient._assert_gcs_enabled() + bucket_name, blob_name = _CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) + bucket = storage.bucket(bucket_name, app=app) + blob = bucket.blob(blob_name) + return blob.generate_signed_url( + version='v4', + expiration=datetime.timedelta(minutes=10), + method='GET' + ) + + class TFLiteGCSModelSource(TFLiteModelSource): """TFLite model source representing a tflite model file stored in GCS.""" - def __init__(self, gcs_tflite_uri): + + _STORAGE_CLIENT = _CloudStorageClient() + + def __init__(self, gcs_tflite_uri, app=None): + self._app = app self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) def __eq__(self, other): @@ -391,6 +454,81 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + @classmethod + def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): + """Uploads the model file to an existing Google Cloud Storage bucket. + + Args: + model_file_name: The name of the model file. + bucket_name: The name of an existing bucket. None to use the default bucket configured + in the app. + app: A Firebase app instance (or None to use the default app). + + Returns: + TFLiteGCSModelSource: The source created from the model_file + + Raises: + ImportError: If the Cloud Storage Library has not been installed. + """ + gcs_uri = TFLiteGCSModelSource._STORAGE_CLIENT.upload(bucket_name, model_file_name, app) + return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app) + + @staticmethod + def _assert_tf_version_1_enabled(): + if not _TF_ENABLED: + raise ImportError('Failed to import the tensorflow library for Python. Make sure ' + 'to install the tensorflow module.') + if not tf.VERSION.startswith('1.'): + raise ImportError('Expected tensorflow version 1.x, but found {0}'.format(tf.VERSION)) + + @classmethod + def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): + """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS. + + Args: + saved_model_dir: The saved model directory. + bucket_name: The name of an existing bucket. None to use the default bucket configured + in the app. + app: Optional. A Firebase app instance (or None to use the default app) + + Returns: + TFLiteGCSModelSource: The source created from the saved_model_dir + + Raises: + ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. + """ + TFLiteGCSModelSource._assert_tf_version_1_enabled() + converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) + tflite_model = converter.convert() + open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) + return TFLiteGCSModelSource.from_tflite_model_file( + 'firebase_mlkit_model.tflite', bucket_name, app) + + @classmethod + def from_keras_model(cls, keras_model, bucket_name=None, app=None): + """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS. + + Args: + keras_model: A tf.keras model. + bucket_name: The name of an existing bucket. None to use the default bucket configured + in the app. + app: Optional. A Firebase app instance (or None to use the default app) + + Returns: + TFLiteGCSModelSource: The source created from the keras_model + + Raises: + ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. + """ + TFLiteGCSModelSource._assert_tf_version_1_enabled() + keras_file = 'keras_model.h5' + tf.keras.models.save_model(keras_model, keras_file) + converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) + return TFLiteGCSModelSource.from_tflite_model_file( + 'firebase_mlkit_model.tflite', bucket_name, app) + @property def gcs_tflite_uri(self): return self._gcs_tflite_uri @@ -399,10 +537,15 @@ def gcs_tflite_uri(self): def gcs_tflite_uri(self, gcs_tflite_uri): self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) - def as_dict(self): - return {"gcsTfliteUri": self._gcs_tflite_uri} + def _get_signed_gcs_tflite_uri(self): + """Signs the GCS uri, so the model file can be uploaded to Firebase ML Kit and verified.""" + return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app) + + def as_dict(self, for_upload=False): + if for_upload: + return {'gcsTfliteUri': self._get_signed_gcs_tflite_uri()} - #TODO(ifielker): implement from_saved_model etc. + return {'gcsTfliteUri': self._gcs_tflite_uri} class ListModelsPage(object): @@ -671,13 +814,13 @@ def create_model(self, model): _validate_model(model) try: return self.handle_operation( - self._client.body('post', url='models', json=model.as_dict())) + self._client.body('post', url='models', json=model.as_dict(for_upload=True))) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) def update_model(self, model, update_mask=None): _validate_model(model, update_mask) - data = {'model': model.as_dict()} + data = {'model': model.as_dict(for_upload=True)} if update_mask is not None: data['updateMask'] = update_mask try: diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 50fed4e1b..26afdfa99 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -103,7 +103,9 @@ } } -GCS_TFLITE_URI = 'gs://my_bucket/mymodel.tflite' +GCS_BUCKET_NAME = 'my_bucket' +GCS_BLOB_NAME = 'mymodel.tflite' +GCS_TFLITE_URI = 'gs://{0}/{1}'.format(GCS_BUCKET_NAME, GCS_BLOB_NAME) GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI} GCS_TFLITE_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) TFLITE_FORMAT_JSON = { @@ -112,6 +114,10 @@ } TFLITE_FORMAT = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON) +GCS_TFLITE_SIGNED_URI_PATTERN = ( + 'https://storage.googleapis.com/{0}/{1}?X-Goog-Algorithm=GOOG4-RSA-SHA256&foo') +GCS_TFLITE_SIGNED_URI = GCS_TFLITE_SIGNED_URI_PATTERN.format(GCS_BUCKET_NAME, GCS_BLOB_NAME) + GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite' GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2} GCS_TFLITE_MODEL_SOURCE_2 = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI_2) @@ -325,6 +331,18 @@ def instrument_mlkit_service(status=200, payload=None, operations=False, app=Non session_url, adapter(payload, status, recorder)) return recorder +class _TestStorageClient(object): + @staticmethod + def upload(bucket_name, model_file_name, app): + del app # unused variable + blob_name = mlkit._CloudStorageClient.BLOB_NAME.format(model_file_name) + return mlkit._CloudStorageClient.GCS_URI.format(bucket_name, blob_name) + + @staticmethod + def sign_uri(gcs_tflite_uri, app): + del app # unused variable + bucket_name, blob_name = mlkit._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) + return GCS_TFLITE_SIGNED_URI_PATTERN.format(bucket_name, blob_name) class TestModel(object): """Tests mlkit.Model class.""" @@ -333,6 +351,7 @@ def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + mlkit.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient() @classmethod def teardown_class(cls): @@ -404,6 +423,13 @@ def test_model_format_source_creation(self): } } + def test_source_creation_from_tflite_file(self): + model_source = mlkit.TFLiteGCSModelSource.from_tflite_model_file( + "my_model.tflite", "my_bucket") + assert model_source.as_dict() == { + 'gcsTfliteUri': 'gs://my_bucket/Firebase/MLKit/Models/my_model.tflite' + } + def test_model_source_setters(self): model_source = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) model_source.gcs_tflite_uri = GCS_TFLITE_URI_2 @@ -420,6 +446,27 @@ def test_model_format_setters(self): } } + def test_model_as_dict_for_upload(self): + model_source = mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) + model_format = mlkit.TFLiteFormat(model_source=model_source) + model = mlkit.Model(display_name=DISPLAY_NAME_1, model_format=model_format) + assert model.as_dict(for_upload=True) == { + 'displayName': DISPLAY_NAME_1, + 'tfliteModel': { + 'gcsTfliteUri': GCS_TFLITE_SIGNED_URI + } + } + + @pytest.mark.parametrize('helper_func', [ + mlkit.TFLiteGCSModelSource.from_keras_model, + mlkit.TFLiteGCSModelSource.from_saved_model + ]) + def test_tf_not_enabled(self, helper_func): + mlkit._TF_ENABLED = False # for reliability + with pytest.raises(ImportError) as excinfo: + helper_func(None) + check_error(excinfo, ImportError) + @pytest.mark.parametrize('display_name, exc_type', [ ('', ValueError), ('&_*#@:/?', ValueError), @@ -803,6 +850,7 @@ def test_rpc_error(self, publish_function): ) assert len(create_recorder) == 1 + class TestGetModel(object): """Tests mlkit.get_model.""" @classmethod From 7b4731f1984705ba6d82cc6b462a21a695791755 Mon Sep 17 00:00:00 2001 From: Kevin Cheung Date: Mon, 18 Nov 2019 12:23:57 -0800 Subject: [PATCH 24/37] Quick pass at filling in missing docstrings (#367) * Quick pass at filling in missing docstrings * More punctuation --- firebase_admin/mlkit.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 8e78a26ce..a135d2d7f 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -196,6 +196,7 @@ def __init__(self, display_name=None, tags=None, model_format=None): @classmethod def from_dict(cls, data, app=None): + """Create an instance of the object from a dict.""" data_copy = dict(data) tflite_format = None tflite_format_data = data_copy.pop('tfliteModel', None) @@ -223,6 +224,7 @@ def __ne__(self, other): @property def model_id(self): + """The model's ID, unique to the project.""" if not self._data.get('name'): return None _, model_id = _validate_and_parse_name(self._data.get('name')) @@ -230,6 +232,8 @@ def model_id(self): @property def display_name(self): + """The model's display name, used to refer to the model in code and in + the Firebase console.""" return self._data.get('displayName') @display_name.setter @@ -239,7 +243,7 @@ def display_name(self, display_name): @property def create_time(self): - """Returns the creation timestamp""" + """The time the model was created.""" seconds = self._data.get('createTime', {}).get('seconds') if not isinstance(seconds, numbers.Number): return None @@ -248,7 +252,7 @@ def create_time(self): @property def update_time(self): - """Returns the last update timestamp""" + """The time the model was last updated.""" seconds = self._data.get('updateTime', {}).get('seconds') if not isinstance(seconds, numbers.Number): return None @@ -257,22 +261,28 @@ def update_time(self): @property def validation_error(self): + """Validation error message.""" return self._data.get('state', {}).get('validationError', {}).get('message') @property def published(self): + """True if the model is published and available for clients to + download.""" return bool(self._data.get('state', {}).get('published')) @property def etag(self): + """The entity tag (ETag) of the model resource.""" return self._data.get('etag') @property def model_hash(self): + """SHA256 hash of the model binary.""" return self._data.get('modelHash') @property def tags(self): + """Tag strings, used for filtering query results.""" return self._data.get('tags') @tags.setter @@ -282,6 +292,7 @@ def tags(self, tags): @property def locked(self): + """True if the Model object is locked by an active operation.""" return bool(self._data.get('activeOperations') and len(self._data.get('activeOperations')) > 0) @@ -307,6 +318,8 @@ def wait_for_unlocked(self, max_time_seconds=None): @property def model_format(self): + """The model's ``ModelFormat`` object, which represents the model's + format and storage location.""" return self._model_format @model_format.setter @@ -317,6 +330,7 @@ def model_format(self, model_format): return self def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" copy = dict(self._data) if self._model_format: copy.update(self._model_format.as_dict(for_upload=for_upload)) @@ -326,6 +340,7 @@ def as_dict(self, for_upload=False): class ModelFormat(object): """Abstract base class representing a Model Format such as TFLite.""" def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" raise NotImplementedError @@ -344,6 +359,7 @@ def __init__(self, model_source=None): @classmethod def from_dict(cls, data): + """Create an instance of the object from a dict.""" data_copy = dict(data) model_source = None gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None) @@ -366,6 +382,7 @@ def __ne__(self, other): @property def model_source(self): + """The TF Lite model's location.""" return self._model_source @model_source.setter @@ -377,9 +394,11 @@ def model_source(self, model_source): @property def size_bytes(self): + """The size in bytes of the TF Lite model.""" return self._data.get('sizeBytes') def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" copy = dict(self._data) if self._model_source: copy.update(self._model_source.as_dict(for_upload=for_upload)) @@ -389,6 +408,7 @@ def as_dict(self, for_upload=False): class TFLiteModelSource(object): """Abstract base class representing a model source for TFLite format models.""" def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" raise NotImplementedError @@ -415,6 +435,7 @@ def _parse_gcs_tflite_uri(uri): @staticmethod def upload(bucket_name, model_file_name, app): + """Upload a model file to the specified Storage bucket.""" _CloudStorageClient._assert_gcs_enabled() bucket = storage.bucket(bucket_name, app=app) blob_name = _CloudStorageClient.BLOB_NAME.format(model_file_name) @@ -531,6 +552,7 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): @property def gcs_tflite_uri(self): + """URI of the model file in Cloud Storage.""" return self._gcs_tflite_uri @gcs_tflite_uri.setter @@ -542,6 +564,7 @@ def _get_signed_gcs_tflite_uri(self): return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app) def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" if for_upload: return {'gcsTfliteUri': self._get_signed_gcs_tflite_uri()} @@ -578,11 +601,12 @@ def list_filter(self): @property def next_page_token(self): + """Token identifying the next page of results.""" return self._list_response.get('nextPageToken', '') @property def has_next_page(self): - """A boolean indicating whether more pages are available.""" + """True if more pages are available.""" return bool(self.next_page_token) def get_next_page(self): From c6dafbb354932379ffa3ad64c809a42f532a9a59 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 10 Dec 2019 14:13:06 -0500 Subject: [PATCH 25/37] Modify Operation Handling to not require a name for Done Operations (#371) * Firebase ML Kit Modify Operation Handling to not require a name for Done Operations * Adding support for TensorFlow 2.x --- firebase_admin/mlkit.py | 100 ++++++++++++++++++++++++++-------------- tests/test_mlkit.py | 76 +++++++++++------------------- 2 files changed, 93 insertions(+), 83 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index a135d2d7f..bb277abf9 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -27,6 +27,7 @@ import six +from six.moves import urllib from firebase_admin import _http_client from firebase_admin import _utils from firebase_admin import exceptions @@ -200,6 +201,7 @@ def from_dict(cls, data, app=None): data_copy = dict(data) tflite_format = None tflite_format_data = data_copy.pop('tfliteModel', None) + data_copy.pop('@type', None) # Returned by Operations. (Not needed) if tflite_format_data: tflite_format = TFLiteFormat.from_dict(tflite_format_data) model = Model(model_format=tflite_format) @@ -495,12 +497,31 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app) @staticmethod - def _assert_tf_version_1_enabled(): + def _assert_tf_enabled(): if not _TF_ENABLED: raise ImportError('Failed to import the tensorflow library for Python. Make sure ' 'to install the tensorflow module.') - if not tf.VERSION.startswith('1.'): - raise ImportError('Expected tensorflow version 1.x, but found {0}'.format(tf.VERSION)) + if not tf.version.VERSION.startswith('1.') and not tf.version.VERSION.startswith('2.'): + raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}' + .format(tf.version.VERSION)) + + @staticmethod + def _tf_convert_from_saved_model(saved_model_dir): + # Same for both v1.x and v2.x + converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) + return converter.convert() + + @staticmethod + def _tf_convert_from_keras_model(keras_model): + # Version 1.x conversion function takes a model file. Version 2.x takes the model itself. + if tf.version.VERSION.startswith('1.'): + keras_file = 'firebase_keras_model.h5' + tf.keras.models.save_model(keras_model, keras_file) + converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) + return converter.convert() + else: + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + return converter.convert() @classmethod def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): @@ -518,9 +539,8 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): Raises: ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. """ - TFLiteGCSModelSource._assert_tf_version_1_enabled() - converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) - tflite_model = converter.convert() + TFLiteGCSModelSource._assert_tf_enabled() + tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir) open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file( 'firebase_mlkit_model.tflite', bucket_name, app) @@ -541,11 +561,8 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): Raises: ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. """ - TFLiteGCSModelSource._assert_tf_version_1_enabled() - keras_file = 'keras_model.h5' - tf.keras.models.save_model(keras_model, keras_file) - converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) - tflite_model = converter.convert() + TFLiteGCSModelSource._assert_tf_enabled() + tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model) open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file( 'firebase_mlkit_model.tflite', bucket_name, app) @@ -810,28 +827,36 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds """ if not isinstance(operation, dict): raise TypeError('Operation must be a dictionary.') - op_name = operation.get('name') - _, model_id = _validate_and_parse_operation_name(op_name) - - current_attempt = 0 - start_time = datetime.datetime.now() - stop_time = (None if max_time_seconds is None else - start_time + datetime.timedelta(seconds=max_time_seconds)) - while wait_for_operation and not operation.get('done'): - # We just got this operation. Wait before getting another - # so we don't exceed the GetOperation maximum request rate. - self._exponential_backoff(current_attempt, stop_time) - operation = self.get_operation(op_name) - current_attempt += 1 if operation.get('done'): + # Operations which are immediately done don't have an operation name if operation.get('response'): return operation.get('response') elif operation.get('error'): raise _utils.handle_operation_error(operation.get('error')) - - # If the operation is not complete or timed out, return a (locked) model instead - return get_model(model_id).as_dict() + raise exceptions.UnknownError(message='Internal Error: Malformed Operation.') + else: + op_name = operation.get('name') + _, model_id = _validate_and_parse_operation_name(op_name) + current_attempt = 0 + start_time = datetime.datetime.now() + stop_time = (None if max_time_seconds is None else + start_time + datetime.timedelta(seconds=max_time_seconds)) + while wait_for_operation and not operation.get('done'): + # We just got this operation. Wait before getting another + # so we don't exceed the GetOperation maximum request rate. + self._exponential_backoff(current_attempt, stop_time) + operation = self.get_operation(op_name) + current_attempt += 1 + + if operation.get('done'): + if operation.get('response'): + return operation.get('response') + elif operation.get('error'): + raise _utils.handle_operation_error(operation.get('error')) + + # If the operation is not complete or timed out, return a (locked) model instead + return get_model(model_id).as_dict() def create_model(self, model): @@ -844,12 +869,12 @@ def create_model(self, model): def update_model(self, model, update_mask=None): _validate_model(model, update_mask) - data = {'model': model.as_dict(for_upload=True)} + path = 'models/{0}'.format(model.model_id) if update_mask is not None: - data['updateMask'] = update_mask + path = path + '?updateMask={0}'.format(update_mask) try: return self.handle_operation( - self._client.body('patch', url='models/{0}'.format(model.model_id), json=data)) + self._client.body('patch', url=path, json=model.as_dict(for_upload=True))) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) @@ -876,15 +901,20 @@ def list_models(self, list_filter, page_size, page_token): _validate_list_filter(list_filter) _validate_page_size(page_size) _validate_page_token(page_token) - payload = {} + params = {} if list_filter: - payload['list_filter'] = list_filter + params['filter'] = list_filter if page_size: - payload['page_size'] = page_size + params['page_size'] = page_size if page_token: - payload['page_token'] = page_token + params['page_token'] = page_token + path = 'models' + if params: + # pylint: disable=too-many-function-args + param_str = urllib.parse.urlencode(sorted(params.items()), True) + path = path + '?' + param_str try: - return self._client.body('get', url='models', json=payload) + return self._client.body('get', url=path) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 26afdfa99..dbe590673 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -158,23 +158,21 @@ } OPERATION_DONE_MODEL_JSON_1 = { - 'name': OPERATION_NAME_1, 'done': True, 'response': CREATED_UPDATED_MODEL_JSON_1 } OPERATION_MALFORMED_JSON_1 = { - 'name': OPERATION_NAME_1, 'done': True, # if done is true then either response or error should be populated } OPERATION_MISSING_NAME = { + # Name is required if the operation is not done. 'done': False } OPERATION_ERROR_CODE = 400 OPERATION_ERROR_MSG = "Invalid argument" OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT' OPERATION_ERROR_JSON_1 = { - 'name': OPERATION_NAME_1, 'done': True, 'error': { 'code': OPERATION_ERROR_CODE, @@ -609,17 +607,10 @@ def test_operation_error(self): check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): - recorder = instrument_mlkit_service( - status=[200, 200], - payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) - model = mlkit.create_model(MODEL_1) - assert model == expected_model - assert len(recorder) == 2 - assert recorder[0].method == 'POST' - assert recorder[0].url == TestCreateModel._url(PROJECT_ID) - assert recorder[1].method == 'GET' - assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) + instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + mlkit.create_model(MODEL_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') def test_rpc_error_create(self): create_recorder = instrument_mlkit_service( @@ -708,17 +699,10 @@ def test_operation_error(self): check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): - recorder = instrument_mlkit_service( - status=[200, 200], - payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) - model = mlkit.update_model(MODEL_1) - assert model == expected_model - assert len(recorder) == 2 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) - assert recorder[1].method == 'GET' - assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + mlkit.update_model(MODEL_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') def test_rpc_error(self): create_recorder = instrument_mlkit_service( @@ -779,7 +763,13 @@ def teardown_class(cls): testutils.cleanup_apps() @staticmethod - def _url(project_id, model_id): + def _update_url(project_id, model_id): + update_url = 'projects/{0}/models/{1}?updateMask=state.published'.format( + project_id, model_id) + return BASE_URL + update_url + + @staticmethod + def _get_url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) @staticmethod @@ -794,10 +784,9 @@ def test_immediate_done(self, publish_function, published): assert model == CREATED_UPDATED_MODEL_1 assert len(recorder) == 1 assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) body = json.loads(recorder[0].body.decode()) - assert body.get('model', {}).get('state', {}).get('published', None) is published - assert body.get('updateMask', {}) == 'state.published' + assert body.get('state', {}).get('published', None) is published @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_returns_locked(self, publish_function): @@ -810,9 +799,9 @@ def test_returns_locked(self, publish_function): assert model == expected_model assert len(recorder) == 2 assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) assert recorder[1].method == 'GET' - assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1) @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_operation_error(self, publish_function): @@ -824,17 +813,10 @@ def test_operation_error(self, publish_function): @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_malformed_operation(self, publish_function): - recorder = instrument_mlkit_service( - status=[200, 200], - payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) - model = publish_function(MODEL_ID_1) - assert model == expected_model - assert len(recorder) == 2 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) - assert recorder[1].method == 'GET' - assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + publish_function(MODEL_ID_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_rpc_error(self, publish_function): @@ -996,12 +978,10 @@ def test_list_models_with_all_args(self): page_token=PAGE_TOKEN) assert len(recorder) == 1 assert recorder[0].method == 'GET' - assert recorder[0].url == TestListModels._url(PROJECT_ID) - assert json.loads(recorder[0].body.decode()) == { - 'list_filter': 'display_name=displayName3', - 'page_size': 10, - 'page_token': PAGE_TOKEN - } + assert recorder[0].url == ( + TestListModels._url(PROJECT_ID) + + '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' + .format(PAGE_TOKEN)) assert isinstance(models_page, mlkit.ListModelsPage) assert len(models_page.models) == 1 assert models_page.models[0] == MODEL_3 From 079c7e138bdcab09247459eaa1baed95d6d67ce9 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 10 Dec 2019 14:55:14 -0500 Subject: [PATCH 26/37] rename from mlkit to ml (#373) --- firebase_admin/{mlkit.py => ml.py} | 100 +++++----- tests/{test_mlkit.py => test_ml.py} | 284 ++++++++++++++-------------- 2 files changed, 192 insertions(+), 192 deletions(-) rename firebase_admin/{mlkit.py => ml.py} (92%) rename tests/{test_mlkit.py => test_ml.py} (79%) diff --git a/firebase_admin/mlkit.py b/firebase_admin/ml.py similarity index 92% rename from firebase_admin/mlkit.py rename to firebase_admin/ml.py index bb277abf9..809ba9a41 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/ml.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Firebase ML Kit module. +"""Firebase ML module. This module contains functions for creating, updating, getting, listing, -deleting, publishing and unpublishing Firebase ML Kit models. +deleting, publishing and unpublishing Firebase ML models. """ @@ -46,7 +46,7 @@ except ImportError: _TF_ENABLED = False -_MLKIT_ATTRIBUTE = '_mlkit' +_ML_ATTRIBUTE = '_ml' _MAX_PAGE_SIZE = 100 _MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') _DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') @@ -60,51 +60,51 @@ r'/operation/[^/]+$') -def _get_mlkit_service(app): - """ Returns an _MLKitService instance for an App. +def _get_ml_service(app): + """ Returns an _MLService instance for an App. Args: app: A Firebase App instance (or None to use the default App). Returns: - _MLKitService: An _MLKitService for the specified App instance. + _MLService: An _MLService for the specified App instance. Raises: ValueError: If the app argument is invalid. """ - return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService) + return _utils.get_app_service(app, _ML_ATTRIBUTE, _MLService) def create_model(model, app=None): - """Creates a model in Firebase ML Kit. + """Creates a model in Firebase ML. Args: - model: An mlkit.Model to create. + model: An ml.Model to create. app: A Firebase app instance (or None to use the default app). Returns: - Model: The model that was created in Firebase ML Kit. + Model: The model that was created in Firebase ML. """ - mlkit_service = _get_mlkit_service(app) - return Model.from_dict(mlkit_service.create_model(model), app=app) + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.create_model(model), app=app) def update_model(model, app=None): - """Updates a model in Firebase ML Kit. + """Updates a model in Firebase ML. Args: - model: The mlkit.Model to update. + model: The ml.Model to update. app: A Firebase app instance (or None to use the default app). Returns: Model: The updated model. """ - mlkit_service = _get_mlkit_service(app) - return Model.from_dict(mlkit_service.update_model(model), app=app) + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.update_model(model), app=app) def publish_model(model_id, app=None): - """Publishes a model in Firebase ML Kit. + """Publishes a model in Firebase ML. Args: model_id: The id of the model to publish. @@ -113,12 +113,12 @@ def publish_model(model_id, app=None): Returns: Model: The published model. """ - mlkit_service = _get_mlkit_service(app) - return Model.from_dict(mlkit_service.set_published(model_id, publish=True), app=app) + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.set_published(model_id, publish=True), app=app) def unpublish_model(model_id, app=None): - """Unpublishes a model in Firebase ML Kit. + """Unpublishes a model in Firebase ML. Args: model_id: The id of the model to unpublish. @@ -127,12 +127,12 @@ def unpublish_model(model_id, app=None): Returns: Model: The unpublished model. """ - mlkit_service = _get_mlkit_service(app) - return Model.from_dict(mlkit_service.set_published(model_id, publish=False), app=app) + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.set_published(model_id, publish=False), app=app) def get_model(model_id, app=None): - """Gets a model from Firebase ML Kit. + """Gets a model from Firebase ML. Args: model_id: The id of the model to get. @@ -141,12 +141,12 @@ def get_model(model_id, app=None): Returns: Model: The requested model. """ - mlkit_service = _get_mlkit_service(app) - return Model.from_dict(mlkit_service.get_model(model_id), app=app) + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.get_model(model_id), app=app) def list_models(list_filter=None, page_size=None, page_token=None, app=None): - """Lists models from Firebase ML Kit. + """Lists models from Firebase ML. Args: list_filter: a list filter string such as "tags:'tag_1'". None will return all models. @@ -159,24 +159,24 @@ def list_models(list_filter=None, page_size=None, page_token=None, app=None): Returns: ListModelsPage: A (filtered) list of models. """ - mlkit_service = _get_mlkit_service(app) + ml_service = _get_ml_service(app) return ListModelsPage( - mlkit_service.list_models, list_filter, page_size, page_token, app=app) + ml_service.list_models, list_filter, page_size, page_token, app=app) def delete_model(model_id, app=None): - """Deletes a model from Firebase ML Kit. + """Deletes a model from Firebase ML. Args: model_id: The id of the model you wish to delete. app: A Firebase app instance (or None to use the default app). """ - mlkit_service = _get_mlkit_service(app) - mlkit_service.delete_model(model_id) + ml_service = _get_ml_service(app) + ml_service.delete_model(model_id) class Model(object): - """A Firebase ML Kit Model object. + """A Firebase ML Model object. Args: display_name: The display name of your model - used to identify your model in code. @@ -310,10 +310,10 @@ def wait_for_unlocked(self, max_time_seconds=None): """ if not self.locked: return - mlkit_service = _get_mlkit_service(self._app) + ml_service = _get_ml_service(self._app) op_name = self._data.get('activeOperations')[0].get('name') - model_dict = mlkit_service.handle_operation( - mlkit_service.get_operation(op_name), + model_dict = ml_service.handle_operation( + ml_service.get_operation(op_name), wait_for_operation=True, max_time_seconds=max_time_seconds) self._update_from_dict(model_dict) @@ -418,7 +418,7 @@ class _CloudStorageClient(object): """Cloud Storage helper class""" GCS_URI = 'gs://{0}/{1}' - BLOB_NAME = 'Firebase/MLKit/Models/{0}' + BLOB_NAME = 'Firebase/ML/Models/{0}' @staticmethod def _assert_gcs_enabled(): @@ -541,9 +541,9 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): """ TFLiteGCSModelSource._assert_tf_enabled() tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir) - open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) + open('firebase_ml_model.tflite', 'wb').write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file( - 'firebase_mlkit_model.tflite', bucket_name, app) + 'firebase_ml_model.tflite', bucket_name, app) @classmethod def from_keras_model(cls, keras_model, bucket_name=None, app=None): @@ -563,9 +563,9 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): """ TFLiteGCSModelSource._assert_tf_enabled() tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model) - open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) + open('firebase_ml_model.tflite', 'wb').write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file( - 'firebase_mlkit_model.tflite', bucket_name, app) + 'firebase_ml_model.tflite', bucket_name, app) @property def gcs_tflite_uri(self): @@ -577,7 +577,7 @@ def gcs_tflite_uri(self, gcs_tflite_uri): self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) def _get_signed_gcs_tflite_uri(self): - """Signs the GCS uri, so the model file can be uploaded to Firebase ML Kit and verified.""" + """Signs the GCS uri, so the model file can be uploaded to Firebase ML and verified.""" return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app) def as_dict(self, for_upload=False): @@ -697,7 +697,7 @@ def _validate_and_parse_name(name): def _validate_model(model, update_mask=None): if not isinstance(model, Model): - raise TypeError('Model must be an mlkit.Model.') + raise TypeError('Model must be an ml.Model.') if update_mask is None and not model.display_name: raise ValueError('Model must have a display name.') @@ -765,8 +765,8 @@ def _validate_page_token(page_token): raise TypeError('Page token must be a string or None.') -class _MLKitService(object): - """Firebase MLKit service.""" +class _MLService(object): + """Firebase ML service.""" PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/' OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/' @@ -777,15 +777,15 @@ def __init__(self, app): self._project_id = app.project_id if not self._project_id: raise ValueError( - 'Project ID is required to access MLKit service. Either set the ' + 'Project ID is required to access ML service. Either set the ' 'projectId option, or use service account credentials.') - self._project_url = _MLKitService.PROJECT_URL.format(self._project_id) + self._project_url = _MLService.PROJECT_URL.format(self._project_id) self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), base_url=self._project_url) self._operation_client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), - base_url=_MLKitService.OPERATION_URL) + base_url=_MLService.OPERATION_URL) def get_operation(self, op_name): _validate_and_parse_operation_name(op_name) @@ -796,8 +796,8 @@ def get_operation(self, op_name): def _exponential_backoff(self, current_attempt, stop_time): """Sleeps for the appropriate amount of time. Or throws deadline exceeded.""" - delay_factor = pow(_MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) - wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS + delay_factor = pow(_MLService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) + wait_time_seconds = delay_factor * _MLService.POLL_BASE_WAIT_TIME_SECONDS if stop_time is not None: max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds() @@ -897,7 +897,7 @@ def get_model(self, model_id): raise _utils.handle_platform_error_from_requests(error) def list_models(self, list_filter, page_size, page_token): - """ lists Firebase ML Kit models.""" + """ lists Firebase ML models.""" _validate_list_filter(list_filter) _validate_page_size(page_size) _validate_page_token(page_token) diff --git a/tests/test_mlkit.py b/tests/test_ml.py similarity index 79% rename from tests/test_mlkit.py rename to tests/test_ml.py index dbe590673..e66507e88 100644 --- a/tests/test_mlkit.py +++ b/tests/test_ml.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test cases for the firebase_admin.mlkit module.""" +"""Test cases for the firebase_admin.ml module.""" import datetime import json @@ -20,7 +20,7 @@ import firebase_admin from firebase_admin import exceptions -from firebase_admin import mlkit +from firebase_admin import ml from tests import testutils BASE_URL = 'https://mlkit.googleapis.com/v1beta1/' @@ -61,7 +61,7 @@ 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1 } -MODEL_1 = mlkit.Model.from_dict(MODEL_JSON_1) +MODEL_1 = ml.Model.from_dict(MODEL_JSON_1) MODEL_ID_2 = 'modelId2' MODEL_NAME_2 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_2) @@ -70,7 +70,7 @@ 'name': MODEL_NAME_2, 'displayName': DISPLAY_NAME_2 } -MODEL_2 = mlkit.Model.from_dict(MODEL_JSON_2) +MODEL_2 = ml.Model.from_dict(MODEL_JSON_2) MODEL_ID_3 = 'modelId3' MODEL_NAME_3 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_3) @@ -79,7 +79,7 @@ 'name': MODEL_NAME_3, 'displayName': DISPLAY_NAME_3 } -MODEL_3 = mlkit.Model.from_dict(MODEL_JSON_3) +MODEL_3 = ml.Model.from_dict(MODEL_JSON_3) MODEL_STATE_PUBLISHED_JSON = { 'published': True @@ -107,12 +107,12 @@ GCS_BLOB_NAME = 'mymodel.tflite' GCS_TFLITE_URI = 'gs://{0}/{1}'.format(GCS_BUCKET_NAME, GCS_BLOB_NAME) GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI} -GCS_TFLITE_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) +GCS_TFLITE_MODEL_SOURCE = ml.TFLiteGCSModelSource(GCS_TFLITE_URI) TFLITE_FORMAT_JSON = { 'gcsTfliteUri': GCS_TFLITE_URI, 'sizeBytes': '1234567' } -TFLITE_FORMAT = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON) +TFLITE_FORMAT = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON) GCS_TFLITE_SIGNED_URI_PATTERN = ( 'https://storage.googleapis.com/{0}/{1}?X-Goog-Algorithm=GOOG4-RSA-SHA256&foo') @@ -120,12 +120,12 @@ GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite' GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2} -GCS_TFLITE_MODEL_SOURCE_2 = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI_2) +GCS_TFLITE_MODEL_SOURCE_2 = ml.TFLiteGCSModelSource(GCS_TFLITE_URI_2) TFLITE_FORMAT_JSON_2 = { 'gcsTfliteUri': GCS_TFLITE_URI_2, 'sizeBytes': '2345678' } -TFLITE_FORMAT_2 = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) +TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) CREATED_UPDATED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, @@ -137,7 +137,7 @@ 'modelHash': MODEL_HASH, 'tags': TAGS, } -CREATED_UPDATED_MODEL_1 = mlkit.Model.from_dict(CREATED_UPDATED_MODEL_JSON_1) +CREATED_UPDATED_MODEL_1 = ml.Model.from_dict(CREATED_UPDATED_MODEL_JSON_1) LOCKED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, @@ -202,7 +202,7 @@ 'tags': TAGS, 'tfliteModel': TFLITE_FORMAT_JSON } -FULL_MODEL_PUBLISHED = mlkit.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) +FULL_MODEL_PUBLISHED = ml.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) OPERATION_DONE_FULL_MODEL_PUBLISHED_JSON = { 'name': OPERATION_NAME_1, 'done': True, @@ -279,7 +279,7 @@ 'operations/project/1234/model/abc/operation/123/extrathing', ] PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ - '1 and {0}'.format(mlkit._MAX_PAGE_SIZE) + '1 and {0}'.format(ml._MAX_PAGE_SIZE) INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, list(), dict()] @@ -309,10 +309,10 @@ def check_firebase_error(excinfo, code, status, msg): assert str(err) == msg -def instrument_mlkit_service(status=200, payload=None, operations=False, app=None): +def instrument_ml_service(status=200, payload=None, operations=False, app=None): if not app: app = firebase_admin.get_app() - mlkit_service = mlkit._get_mlkit_service(app) + ml_service = ml._get_ml_service(app) recorder = [] session_url = 'https://mlkit.googleapis.com/v1beta1/' @@ -322,10 +322,10 @@ def instrument_mlkit_service(status=200, payload=None, operations=False, app=Non adapter = testutils.MockAdapter if operations: - mlkit_service._operation_client.session.mount( + ml_service._operation_client.session.mount( session_url, adapter(payload, status, recorder)) else: - mlkit_service._client.session.mount( + ml_service._client.session.mount( session_url, adapter(payload, status, recorder)) return recorder @@ -333,23 +333,23 @@ class _TestStorageClient(object): @staticmethod def upload(bucket_name, model_file_name, app): del app # unused variable - blob_name = mlkit._CloudStorageClient.BLOB_NAME.format(model_file_name) - return mlkit._CloudStorageClient.GCS_URI.format(bucket_name, blob_name) + blob_name = ml._CloudStorageClient.BLOB_NAME.format(model_file_name) + return ml._CloudStorageClient.GCS_URI.format(bucket_name, blob_name) @staticmethod def sign_uri(gcs_tflite_uri, app): del app # unused variable - bucket_name, blob_name = mlkit._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) + bucket_name, blob_name = ml._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) return GCS_TFLITE_SIGNED_URI_PATTERN.format(bucket_name, blob_name) class TestModel(object): - """Tests mlkit.Model class.""" + """Tests ml.Model class.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) - mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test - mlkit.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient() + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + ml.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient() @classmethod def teardown_class(cls): @@ -361,7 +361,7 @@ def _op_url(project_id, model_id): 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) def test_model_success_err_state_lro(self): - model = mlkit.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) + model = ml.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 assert model.create_time == CREATE_TIME_DATETIME @@ -376,7 +376,7 @@ def test_model_success_err_state_lro(self): assert model.as_dict() == FULL_MODEL_ERR_STATE_LRO_JSON def test_model_success_published(self): - model = mlkit.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) + model = ml.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 assert model.create_time == CREATE_TIME_DATETIME @@ -391,7 +391,7 @@ def test_model_success_published(self): assert model.as_dict() == FULL_MODEL_PUBLISHED_JSON def test_model_keyword_based_creation_and_setters(self): - model = mlkit.Model(display_name=DISPLAY_NAME_1, tags=TAGS, model_format=TFLITE_FORMAT) + model = ml.Model(display_name=DISPLAY_NAME_1, tags=TAGS, model_format=TFLITE_FORMAT) assert model.display_name == DISPLAY_NAME_1 assert model.tags == TAGS assert model.model_format == TFLITE_FORMAT @@ -411,9 +411,9 @@ def test_model_keyword_based_creation_and_setters(self): } def test_model_format_source_creation(self): - model_source = mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) - model_format = mlkit.TFLiteFormat(model_source=model_source) - model = mlkit.Model(display_name=DISPLAY_NAME_1, model_format=model_format) + model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) + model_format = ml.TFLiteFormat(model_source=model_source) + model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) assert model.as_dict() == { 'displayName': DISPLAY_NAME_1, 'tfliteModel': { @@ -422,20 +422,20 @@ def test_model_format_source_creation(self): } def test_source_creation_from_tflite_file(self): - model_source = mlkit.TFLiteGCSModelSource.from_tflite_model_file( + model_source = ml.TFLiteGCSModelSource.from_tflite_model_file( "my_model.tflite", "my_bucket") assert model_source.as_dict() == { - 'gcsTfliteUri': 'gs://my_bucket/Firebase/MLKit/Models/my_model.tflite' + 'gcsTfliteUri': 'gs://my_bucket/Firebase/ML/Models/my_model.tflite' } def test_model_source_setters(self): - model_source = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) + model_source = ml.TFLiteGCSModelSource(GCS_TFLITE_URI) model_source.gcs_tflite_uri = GCS_TFLITE_URI_2 assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2 def test_model_format_setters(self): - model_format = mlkit.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) + model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 assert model_format.model_source == GCS_TFLITE_MODEL_SOURCE_2 assert model_format.as_dict() == { @@ -445,9 +445,9 @@ def test_model_format_setters(self): } def test_model_as_dict_for_upload(self): - model_source = mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) - model_format = mlkit.TFLiteFormat(model_source=model_source) - model = mlkit.Model(display_name=DISPLAY_NAME_1, model_format=model_format) + model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) + model_format = ml.TFLiteFormat(model_source=model_source) + model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) assert model.as_dict(for_upload=True) == { 'displayName': DISPLAY_NAME_1, 'tfliteModel': { @@ -456,11 +456,11 @@ def test_model_as_dict_for_upload(self): } @pytest.mark.parametrize('helper_func', [ - mlkit.TFLiteGCSModelSource.from_keras_model, - mlkit.TFLiteGCSModelSource.from_saved_model + ml.TFLiteGCSModelSource.from_keras_model, + ml.TFLiteGCSModelSource.from_saved_model ]) def test_tf_not_enabled(self, helper_func): - mlkit._TF_ENABLED = False # for reliability + ml._TF_ENABLED = False # for reliability with pytest.raises(ImportError) as excinfo: helper_func(None) check_error(excinfo, ImportError) @@ -472,7 +472,7 @@ def test_tf_not_enabled(self, helper_func): ]) def test_model_display_name_validation_errors(self, display_name, exc_type): with pytest.raises(exc_type) as excinfo: - mlkit.Model(display_name=display_name) + ml.Model(display_name=display_name) check_error(excinfo, exc_type) @pytest.mark.parametrize('tags, exc_type, error_message', [ @@ -486,7 +486,7 @@ def test_model_display_name_validation_errors(self, display_name, exc_type): ]) def test_model_tags_validation_errors(self, tags, exc_type, error_message): with pytest.raises(exc_type) as excinfo: - mlkit.Model(tags=tags) + ml.Model(tags=tags) check_error(excinfo, exc_type, error_message) @pytest.mark.parametrize('model_format', [ @@ -498,7 +498,7 @@ def test_model_tags_validation_errors(self, tags, exc_type, error_message): ]) def test_model_format_validation_errors(self, model_format): with pytest.raises(TypeError) as excinfo: - mlkit.Model(model_format=model_format) + ml.Model(model_format=model_format) check_error(excinfo, TypeError, 'Model format must be a ModelFormat object.') @pytest.mark.parametrize('model_source', [ @@ -510,7 +510,7 @@ def test_model_format_validation_errors(self, model_format): ]) def test_model_source_validation_errors(self, model_source): with pytest.raises(TypeError) as excinfo: - mlkit.TFLiteFormat(model_source=model_source) + ml.TFLiteFormat(model_source=model_source) check_error(excinfo, TypeError, 'Model source must be a TFLiteModelSource object.') @pytest.mark.parametrize('uri, exc_type', [ @@ -526,18 +526,18 @@ def test_model_source_validation_errors(self, model_source): ]) def test_gcs_tflite_source_validation_errors(self, uri, exc_type): with pytest.raises(exc_type) as excinfo: - mlkit.TFLiteGCSModelSource(gcs_tflite_uri=uri) + ml.TFLiteGCSModelSource(gcs_tflite_uri=uri) check_error(excinfo, exc_type) def test_wait_for_unlocked_not_locked(self): - model = mlkit.Model(display_name="not_locked") + model = ml.Model(display_name="not_locked") model.wait_for_unlocked() def test_wait_for_unlocked(self): - recorder = instrument_mlkit_service(status=200, - operations=True, - payload=OPERATION_DONE_PUBLISHED_RESPONSE) - model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_1) + recorder = instrument_ml_service(status=200, + operations=True, + payload=OPERATION_DONE_PUBLISHED_RESPONSE) + model = ml.Model.from_dict(LOCKED_MODEL_JSON_1) model.wait_for_unlocked() assert model == FULL_MODEL_PUBLISHED assert len(recorder) == 1 @@ -545,10 +545,10 @@ def test_wait_for_unlocked(self): assert recorder[0].url == TestModel._op_url(PROJECT_ID, MODEL_ID_1) def test_wait_for_unlocked_timeout(self): - recorder = instrument_mlkit_service( + recorder = instrument_ml_service( status=200, operations=True, payload=OPERATION_NOT_DONE_RESPONSE) - mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 3 # longer so timeout applies immediately - model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_1) + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 3 # longer so timeout applies immediately + model = ml.Model.from_dict(LOCKED_MODEL_JSON_1) with pytest.raises(Exception) as excinfo: model.wait_for_unlocked(max_time_seconds=0.1) check_error(excinfo, exceptions.DeadlineExceededError, 'Polling max time exceeded.') @@ -556,12 +556,12 @@ def test_wait_for_unlocked_timeout(self): class TestCreateModel(object): - """Tests mlkit.create_model.""" + """Tests ml.create_model.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) - mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test @classmethod def teardown_class(cls): @@ -581,16 +581,16 @@ def _get_url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) def test_immediate_done(self): - instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) - model = mlkit.create_model(MODEL_1) + instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = ml.create_model(MODEL_1) assert model == CREATED_UPDATED_MODEL_1 def test_returns_locked(self): - recorder = instrument_mlkit_service( + recorder = instrument_ml_service( status=[200, 200], payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) - model = mlkit.create_model(MODEL_1) + expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) + model = ml.create_model(MODEL_1) assert model == expected_model assert len(recorder) == 2 @@ -600,23 +600,23 @@ def test_returns_locked(self): assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) def test_operation_error(self): - instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) with pytest.raises(Exception) as excinfo: - mlkit.create_model(MODEL_1) + ml.create_model(MODEL_1) # The http request succeeded, the operation returned contains a create failure check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): - instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) with pytest.raises(Exception) as excinfo: - mlkit.create_model(MODEL_1) + ml.create_model(MODEL_1) check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') def test_rpc_error_create(self): - create_recorder = instrument_mlkit_service( + create_recorder = instrument_ml_service( status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(Exception) as excinfo: - mlkit.create_model(MODEL_1) + ml.create_model(MODEL_1) check_firebase_error( excinfo, ERROR_STATUS_BAD_REQUEST, @@ -628,36 +628,36 @@ def test_rpc_error_create(self): @pytest.mark.parametrize('model', INVALID_MODEL_ARGS) def test_not_model(self, model): with pytest.raises(Exception) as excinfo: - mlkit.create_model(model) - check_error(excinfo, TypeError, 'Model must be an mlkit.Model.') + ml.create_model(model) + check_error(excinfo, TypeError, 'Model must be an ml.Model.') def test_missing_display_name(self): with pytest.raises(Exception) as excinfo: - mlkit.create_model(mlkit.Model.from_dict({})) + ml.create_model(ml.Model.from_dict({})) check_error(excinfo, ValueError, 'Model must have a display name.') def test_missing_op_name(self): - instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) with pytest.raises(Exception) as excinfo: - mlkit.create_model(MODEL_1) + ml.create_model(MODEL_1) check_error(excinfo, TypeError) @pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS) def test_invalid_op_name(self, op_name): payload = json.dumps({'name': op_name}) - instrument_mlkit_service(status=200, payload=payload) + instrument_ml_service(status=200, payload=payload) with pytest.raises(Exception) as excinfo: - mlkit.create_model(MODEL_1) + ml.create_model(MODEL_1) check_error(excinfo, ValueError, 'Operation name format is invalid.') class TestUpdateModel(object): - """Tests mlkit.update_model.""" + """Tests ml.update_model.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) - mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test @classmethod def teardown_class(cls): @@ -673,16 +673,16 @@ def _op_url(project_id, model_id): 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) def test_immediate_done(self): - instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) - model = mlkit.update_model(MODEL_1) + instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = ml.update_model(MODEL_1) assert model == CREATED_UPDATED_MODEL_1 def test_returns_locked(self): - recorder = instrument_mlkit_service( + recorder = instrument_ml_service( status=[200, 200], payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) - model = mlkit.update_model(MODEL_1) + expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) + model = ml.update_model(MODEL_1) assert model == expected_model assert len(recorder) == 2 @@ -692,23 +692,23 @@ def test_returns_locked(self): assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) def test_operation_error(self): - instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) with pytest.raises(Exception) as excinfo: - mlkit.update_model(MODEL_1) + ml.update_model(MODEL_1) # The http request succeeded, the operation returned contains an update failure check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): - instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) with pytest.raises(Exception) as excinfo: - mlkit.update_model(MODEL_1) + ml.update_model(MODEL_1) check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') def test_rpc_error(self): - create_recorder = instrument_mlkit_service( + create_recorder = instrument_ml_service( status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(Exception) as excinfo: - mlkit.update_model(MODEL_1) + ml.update_model(MODEL_1) check_firebase_error( excinfo, ERROR_STATUS_BAD_REQUEST, @@ -720,35 +720,35 @@ def test_rpc_error(self): @pytest.mark.parametrize('model', INVALID_MODEL_ARGS) def test_not_model(self, model): with pytest.raises(Exception) as excinfo: - mlkit.update_model(model) - check_error(excinfo, TypeError, 'Model must be an mlkit.Model.') + ml.update_model(model) + check_error(excinfo, TypeError, 'Model must be an ml.Model.') def test_missing_display_name(self): with pytest.raises(Exception) as excinfo: - mlkit.update_model(mlkit.Model.from_dict({})) + ml.update_model(ml.Model.from_dict({})) check_error(excinfo, ValueError, 'Model must have a display name.') def test_missing_op_name(self): - instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) with pytest.raises(Exception) as excinfo: - mlkit.update_model(MODEL_1) + ml.update_model(MODEL_1) check_error(excinfo, TypeError) @pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS) def test_invalid_op_name(self, op_name): payload = json.dumps({'name': op_name}) - instrument_mlkit_service(status=200, payload=payload) + instrument_ml_service(status=200, payload=payload) with pytest.raises(Exception) as excinfo: - mlkit.update_model(MODEL_1) + ml.update_model(MODEL_1) check_error(excinfo, ValueError, 'Operation name format is invalid.') class TestPublishUnpublish(object): - """Tests mlkit.publish_model and mlkit.unpublish_model.""" + """Tests ml.publish_model and ml.unpublish_model.""" PUBLISH_UNPUBLISH_WITH_ARGS = [ - (mlkit.publish_model, True), - (mlkit.unpublish_model, False) + (ml.publish_model, True), + (ml.unpublish_model, False) ] PUBLISH_UNPUBLISH_FUNCS = [item[0] for item in PUBLISH_UNPUBLISH_WITH_ARGS] @@ -756,7 +756,7 @@ class TestPublishUnpublish(object): def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) - mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test @classmethod def teardown_class(cls): @@ -779,7 +779,7 @@ def _op_url(project_id, model_id): @pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS) def test_immediate_done(self, publish_function, published): - recorder = instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) + recorder = instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) model = publish_function(MODEL_ID_1) assert model == CREATED_UPDATED_MODEL_1 assert len(recorder) == 1 @@ -790,10 +790,10 @@ def test_immediate_done(self, publish_function, published): @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_returns_locked(self, publish_function): - recorder = instrument_mlkit_service( + recorder = instrument_ml_service( status=[200, 200], payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) + expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) model = publish_function(MODEL_ID_1) assert model == expected_model @@ -805,7 +805,7 @@ def test_returns_locked(self, publish_function): @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_operation_error(self, publish_function): - instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) with pytest.raises(Exception) as excinfo: publish_function(MODEL_ID_1) # The http request succeeded, the operation returned contains an update failure @@ -813,14 +813,14 @@ def test_operation_error(self, publish_function): @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_malformed_operation(self, publish_function): - instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) with pytest.raises(Exception) as excinfo: publish_function(MODEL_ID_1) check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_rpc_error(self, publish_function): - create_recorder = instrument_mlkit_service( + create_recorder = instrument_ml_service( status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(Exception) as excinfo: publish_function(MODEL_ID_1) @@ -834,7 +834,7 @@ def test_rpc_error(self, publish_function): class TestGetModel(object): - """Tests mlkit.get_model.""" + """Tests ml.get_model.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() @@ -849,8 +849,8 @@ def _url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) def test_get_model(self): - recorder = instrument_mlkit_service(status=200, payload=DEFAULT_GET_RESPONSE) - model = mlkit.get_model(MODEL_ID_1) + recorder = instrument_ml_service(status=200, payload=DEFAULT_GET_RESPONSE) + model = ml.get_model(MODEL_ID_1) assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) @@ -861,13 +861,13 @@ def test_get_model(self): @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) def test_get_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as excinfo: - mlkit.get_model(model_id) + ml.get_model(model_id) check_error(excinfo, exc_type) def test_get_model_error(self): - recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) + recorder = instrument_ml_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) with pytest.raises(exceptions.NotFoundError) as excinfo: - mlkit.get_model(MODEL_ID_1) + ml.get_model(MODEL_ID_1) check_firebase_error( excinfo, ERROR_STATUS_NOT_FOUND, @@ -882,12 +882,12 @@ def test_no_project_id(self): def evaluate(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') with pytest.raises(ValueError): - mlkit.get_model(MODEL_ID_1, app) + ml.get_model(MODEL_ID_1, app) testutils.run_without_project_id(evaluate) class TestDeleteModel(object): - """Tests mlkit.delete_model.""" + """Tests ml.delete_model.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() @@ -902,8 +902,8 @@ def _url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) def test_delete_model(self): - recorder = instrument_mlkit_service(status=200, payload=EMPTY_RESPONSE) - mlkit.delete_model(MODEL_ID_1) # no response for delete + recorder = instrument_ml_service(status=200, payload=EMPTY_RESPONSE) + ml.delete_model(MODEL_ID_1) # no response for delete assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) @@ -911,13 +911,13 @@ def test_delete_model(self): @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) def test_delete_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as excinfo: - mlkit.delete_model(model_id) + ml.delete_model(model_id) check_error(excinfo, exc_type) def test_delete_model_error(self): - recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) + recorder = instrument_ml_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) with pytest.raises(exceptions.NotFoundError) as excinfo: - mlkit.delete_model(MODEL_ID_1) + ml.delete_model(MODEL_ID_1) check_firebase_error( excinfo, ERROR_STATUS_NOT_FOUND, @@ -932,12 +932,12 @@ def test_no_project_id(self): def evaluate(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') with pytest.raises(ValueError): - mlkit.delete_model(MODEL_ID_1, app) + ml.delete_model(MODEL_ID_1, app) testutils.run_without_project_id(evaluate) class TestListModels(object): - """Tests mlkit.list_models.""" + """Tests ml.list_models.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() @@ -953,14 +953,14 @@ def _url(project_id): @staticmethod def _check_page(page, model_count): - assert isinstance(page, mlkit.ListModelsPage) + assert isinstance(page, ml.ListModelsPage) assert len(page.models) == model_count for model in page.models: - assert isinstance(model, mlkit.Model) + assert isinstance(model, ml.Model) def test_list_models_no_args(self): - recorder = instrument_mlkit_service(status=200, payload=DEFAULT_LIST_RESPONSE) - models_page = mlkit.list_models() + recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) + models_page = ml.list_models() assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestListModels._url(PROJECT_ID) @@ -971,8 +971,8 @@ def test_list_models_no_args(self): assert models_page.models[1] == MODEL_2 def test_list_models_with_all_args(self): - recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) - models_page = mlkit.list_models( + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + models_page = ml.list_models( 'display_name=displayName3', page_size=10, page_token=PAGE_TOKEN) @@ -982,7 +982,7 @@ def test_list_models_with_all_args(self): TestListModels._url(PROJECT_ID) + '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' .format(PAGE_TOKEN)) - assert isinstance(models_page, mlkit.ListModelsPage) + assert isinstance(models_page, ml.ListModelsPage) assert len(models_page.models) == 1 assert models_page.models[0] == MODEL_3 assert not models_page.has_next_page @@ -990,7 +990,7 @@ def test_list_models_with_all_args(self): @pytest.mark.parametrize('list_filter', INVALID_STRING_OR_NONE_ARGS) def test_list_models_list_filter_validation(self, list_filter): with pytest.raises(TypeError) as excinfo: - mlkit.list_models(list_filter=list_filter) + ml.list_models(list_filter=list_filter) check_error(excinfo, TypeError, 'List filter must be a string or None.') @pytest.mark.parametrize('page_size, exc_type, error_message', [ @@ -1001,23 +1001,23 @@ def test_list_models_list_filter_validation(self, list_filter): (True, TypeError, 'Page size must be a number or None.'), (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), - (mlkit._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) + (ml._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) ]) def test_list_models_page_size_validation(self, page_size, exc_type, error_message): with pytest.raises(exc_type) as excinfo: - mlkit.list_models(page_size=page_size) + ml.list_models(page_size=page_size) check_error(excinfo, exc_type, error_message) @pytest.mark.parametrize('page_token', INVALID_STRING_OR_NONE_ARGS) def test_list_models_page_token_validation(self, page_token): with pytest.raises(TypeError) as excinfo: - mlkit.list_models(page_token=page_token) + ml.list_models(page_token=page_token) check_error(excinfo, TypeError, 'Page token must be a string or None.') def test_list_models_error(self): - recorder = instrument_mlkit_service(status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + recorder = instrument_ml_service(status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(exceptions.InvalidArgumentError) as excinfo: - mlkit.list_models() + ml.list_models() check_firebase_error( excinfo, ERROR_STATUS_BAD_REQUEST, @@ -1032,12 +1032,12 @@ def test_no_project_id(self): def evaluate(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') with pytest.raises(ValueError): - mlkit.list_models(app=app) + ml.list_models(app=app) testutils.run_without_project_id(evaluate) def test_list_single_page(self): - recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) - models_page = mlkit.list_models() + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + models_page = ml.list_models() assert len(recorder) == 1 assert models_page.next_page_token == '' assert models_page.has_next_page is False @@ -1047,15 +1047,15 @@ def test_list_single_page(self): def test_list_multiple_pages(self): # Page 1 - recorder = instrument_mlkit_service(status=200, payload=DEFAULT_LIST_RESPONSE) - page = mlkit.list_models() + recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) + page = ml.list_models() assert len(recorder) == 1 assert len(page.models) == 2 assert page.next_page_token == NEXT_PAGE_TOKEN assert page.has_next_page is True # Page 2 - recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) page_2 = page.get_next_page() assert len(recorder) == 1 assert len(page_2.models) == 1 @@ -1065,8 +1065,8 @@ def test_list_multiple_pages(self): def test_list_models_paged_iteration(self): # Page 1 - recorder = instrument_mlkit_service(status=200, payload=DEFAULT_LIST_RESPONSE) - page = mlkit.list_models() + recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) + page = ml.list_models() assert page.next_page_token == NEXT_PAGE_TOKEN assert page.has_next_page is True iterator = page.iterate_all() @@ -1076,15 +1076,15 @@ def test_list_models_paged_iteration(self): assert len(recorder) == 1 # Page 2 - recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) model = next(iterator) assert model.display_name == DISPLAY_NAME_3 with pytest.raises(StopIteration): next(iterator) def test_list_models_stop_iteration(self): - recorder = instrument_mlkit_service(status=200, payload=ONE_PAGE_LIST_RESPONSE) - page = mlkit.list_models() + recorder = instrument_ml_service(status=200, payload=ONE_PAGE_LIST_RESPONSE) + page = ml.list_models() assert len(recorder) == 1 assert len(page.models) == 3 iterator = page.iterate_all() @@ -1095,8 +1095,8 @@ def test_list_models_stop_iteration(self): assert len(models) == 3 def test_list_models_no_models(self): - recorder = instrument_mlkit_service(status=200, payload=NO_MODELS_LIST_RESPONSE) - page = mlkit.list_models() + recorder = instrument_ml_service(status=200, payload=NO_MODELS_LIST_RESPONSE) + page = ml.list_models() assert len(recorder) == 1 assert len(page.models) == 0 models = [model for model in page.iterate_all()] From a13d2a7e1af608893b1551770a87a76bfdd69b99 Mon Sep 17 00:00:00 2001 From: ifielker Date: Fri, 13 Dec 2019 16:21:50 -0500 Subject: [PATCH 27/37] Adding File naming capability to from_saved_model and from_keras_model. (#375) adding File naming capability for ModelSource --- firebase_admin/ml.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 809ba9a41..c6720f081 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -524,11 +524,13 @@ def _tf_convert_from_keras_model(keras_model): return converter.convert() @classmethod - def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): + def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite', + bucket_name=None, app=None): """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS. Args: saved_model_dir: The saved model directory. + model_file_name: The name that the tflite model will be saved as in Cloud Storage. bucket_name: The name of an existing bucket. None to use the default bucket configured in the app. app: Optional. A Firebase app instance (or None to use the default app) @@ -541,16 +543,18 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): """ TFLiteGCSModelSource._assert_tf_enabled() tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir) - open('firebase_ml_model.tflite', 'wb').write(tflite_model) - return TFLiteGCSModelSource.from_tflite_model_file( - 'firebase_ml_model.tflite', bucket_name, app) + with open(model_file_name, 'wb') as model_file: + model_file.write(tflite_model) + return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @classmethod - def from_keras_model(cls, keras_model, bucket_name=None, app=None): + def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite', + bucket_name=None, app=None): """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS. Args: keras_model: A tf.keras model. + model_file_name: The name that the tflite model will be saved as in Cloud Storage. bucket_name: The name of an existing bucket. None to use the default bucket configured in the app. app: Optional. A Firebase app instance (or None to use the default app) @@ -563,9 +567,9 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): """ TFLiteGCSModelSource._assert_tf_enabled() tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model) - open('firebase_ml_model.tflite', 'wb').write(tflite_model) - return TFLiteGCSModelSource.from_tflite_model_file( - 'firebase_ml_model.tflite', bucket_name, app) + with open(model_file_name, 'wb') as model_file: + model_file.write(tflite_model) + return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @property def gcs_tflite_uri(self): From b133bbb345ce7607f282218564d02a07e7baef45 Mon Sep 17 00:00:00 2001 From: ifielker Date: Thu, 23 Jan 2020 17:03:59 -0500 Subject: [PATCH 28/37] Firebase ML Modify Operation Handling Code to match rpc codes not html codes (#390) * Firebase ML Modify Operation Handling Code to match actual codes * apply database fix too --- firebase_admin/_utils.py | 26 ++++++++++++++++++++++++-- requirements.txt | 2 +- setup.py | 2 +- tests/test_db.py | 2 +- tests/test_messaging.py | 4 ++-- tests/test_ml.py | 2 +- 6 files changed, 30 insertions(+), 8 deletions(-) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index fb6e32932..35a1b7467 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -58,6 +58,25 @@ 503: exceptions.UNAVAILABLE, } +# See https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto +_RPC_CODE_TO_ERROR_CODE = { + 1: exceptions.CANCELLED, + 2: exceptions.UNKNOWN, + 3: exceptions.INVALID_ARGUMENT, + 4: exceptions.DEADLINE_EXCEEDED, + 5: exceptions.NOT_FOUND, + 6: exceptions.ALREADY_EXISTS, + 7: exceptions.PERMISSION_DENIED, + 8: exceptions.RESOURCE_EXHAUSTED, + 9: exceptions.FAILED_PRECONDITION, + 10: exceptions.ABORTED, + 11: exceptions.OUT_OF_RANGE, + 13: exceptions.INTERNAL, + 14: exceptions.UNAVAILABLE, + 15: exceptions.DATA_LOSS, + 16: exceptions.UNAUTHENTICATED, +} + def _get_initialized_app(app): if app is None: @@ -120,9 +139,9 @@ def handle_operation_error(error): message='Unknown error while making a remote service call: {0}'.format(error), cause=error) - status_code = error.get('code') + rpc_code = error.get('code') message = error.get('message') - error_code = _http_status_to_error_code(status_code) + error_code = _rpc_code_to_error_code(rpc_code) err_type = _error_code_to_exception_type(error_code) return err_type(message=message) @@ -283,6 +302,9 @@ def _http_status_to_error_code(status): """Maps an HTTP status to a platform error code.""" return _HTTP_STATUS_TO_ERROR_CODE.get(status, exceptions.UNKNOWN) +def _rpc_code_to_error_code(rpc_code): + """Maps an RPC code to a platform error code.""" + return _RPC_CODE_TO_ERROR_CODE.get(rpc_code, exceptions.UNKNOWN) def _error_code_to_exception_type(code): """Maps a platform error code to an exception type.""" diff --git a/requirements.txt b/requirements.txt index 7a8d855bd..c89318a67 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 tox >= 3.6.0 -cachecontrol >= 0.12.4 +cachecontrol >= 0.12.6 google-api-core[grpc] >= 1.7.0, < 2.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 google-cloud-firestore >= 0.31.0; platform.python_implementation != 'PyPy' diff --git a/setup.py b/setup.py index 15ae97f93..c463108d0 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ long_description = ('The Firebase Admin Python SDK enables server-side (backend) Python developers ' 'to integrate Firebase into their services and applications.') install_requires = [ - 'cachecontrol>=0.12.4', + 'cachecontrol>=0.12.6', 'google-api-core[grpc] >= 1.7.0, < 2.0.0dev; platform.python_implementation != "PyPy"', 'google-api-python-client >= 1.7.8', 'google-cloud-firestore>=0.31.0; platform.python_implementation != "PyPy"', diff --git a/tests/test_db.py b/tests/test_db.py index 081c31e3d..e9f8f7dda 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -819,7 +819,7 @@ def test_http_timeout(self): assert ref._client.timeout == 60 assert ref.get() == {} assert len(recorder) == 1 - assert recorder[0]._extra_kwargs['timeout'] == 60 + assert recorder[0]._extra_kwargs['timeout'] == pytest.approx(60, 0.001) def test_app_delete(self): app = firebase_admin.initialize_app( diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 40f6baada..5d048d613 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -1254,7 +1254,7 @@ def test_send(self): msg = messaging.Message(topic='foo') messaging.send(msg) assert len(self.recorder) == 1 - assert self.recorder[0]._extra_kwargs['timeout'] == 4 + assert self.recorder[0]._extra_kwargs['timeout'] == pytest.approx(4, 0.001) def test_topic_management_timeout(self): self.fcm_service._client.session.mount( @@ -1266,7 +1266,7 @@ def test_topic_management_timeout(self): ) messaging.subscribe_to_topic(['1'], 'a') assert len(self.recorder) == 1 - assert self.recorder[0]._extra_kwargs['timeout'] == 4 + assert self.recorder[0]._extra_kwargs['timeout'] == pytest.approx(4, 0.001) class TestSend(object): diff --git a/tests/test_ml.py b/tests/test_ml.py index e66507e88..18aa789f8 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -169,7 +169,7 @@ # Name is required if the operation is not done. 'done': False } -OPERATION_ERROR_CODE = 400 +OPERATION_ERROR_CODE = 3 OPERATION_ERROR_MSG = "Invalid argument" OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT' OPERATION_ERROR_JSON_1 = { From 877044816e81b7bc57bafec502ebb2ac6fac58cf Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 27 Jan 2020 15:21:52 -0500 Subject: [PATCH 29/37] Mlkit fix date handling2 (#391) * Fix create/update date handling * Skip unrelated failing tests (until sync) --- firebase_admin/_utils.py | 2 +- firebase_admin/ml.py | 23 +++++++++-------- tests/test_db.py | 1 + tests/test_ml.py | 56 +++++++++++++++++----------------------- 4 files changed, 37 insertions(+), 45 deletions(-) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 35a1b7467..495632ad2 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -58,6 +58,7 @@ 503: exceptions.UNAVAILABLE, } + # See https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto _RPC_CODE_TO_ERROR_CODE = { 1: exceptions.CANCELLED, @@ -77,7 +78,6 @@ 16: exceptions.UNAUTHENTICATED, } - def _get_initialized_app(app): if app is None: return firebase_admin.get_app() diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index c6720f081..d34bdd581 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -20,7 +20,6 @@ import datetime -import numbers import re import time import requests @@ -243,23 +242,25 @@ def display_name(self, display_name): self._data['displayName'] = _validate_display_name(display_name) return self + @staticmethod + def _convert_to_millis(date_string): + if not date_string: + return None + format_str = '%Y-%m-%dT%H:%M:%S.%fZ' + epoch = datetime.datetime.utcfromtimestamp(0) + datetime_object = datetime.datetime.strptime(date_string, format_str) + millis = int((datetime_object - epoch).total_seconds() * 1000) + return millis + @property def create_time(self): """The time the model was created.""" - seconds = self._data.get('createTime', {}).get('seconds') - if not isinstance(seconds, numbers.Number): - return None - - return datetime.datetime.fromtimestamp(float(seconds)) + return Model._convert_to_millis(self._data.get('createTime', None)) @property def update_time(self): """The time the model was last updated.""" - seconds = self._data.get('updateTime', {}).get('seconds') - if not isinstance(seconds, numbers.Number): - return None - - return datetime.datetime.fromtimestamp(float(seconds)) + return Model._convert_to_millis(self._data.get('updateTime', None)) @property def validation_error(self): diff --git a/tests/test_db.py b/tests/test_db.py index e9f8f7dda..a6e7408ab 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -728,6 +728,7 @@ def test_parse_db_url_errors(self, url, emulator_host): @pytest.mark.parametrize('url', [ 'https://test.firebaseio.com', 'https://test.firebaseio.com/' ]) + @pytest.mark.skip(reason='only skip until mlkit branch is synced with master') def test_valid_db_url(self, url): firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : url}) ref = db.reference() diff --git a/tests/test_ml.py b/tests/test_ml.py index 18aa789f8..e91517dd5 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -14,7 +14,6 @@ """Test cases for the firebase_admin.ml module.""" -import datetime import json import pytest @@ -27,25 +26,16 @@ PROJECT_ID = 'myProject1' PAGE_TOKEN = 'pageToken' NEXT_PAGE_TOKEN = 'nextPageToken' -CREATE_TIME_SECONDS = 1566426374 -CREATE_TIME_SECONDS_2 = 1566426385 -CREATE_TIME_JSON = { - 'seconds': CREATE_TIME_SECONDS -} -CREATE_TIME_DATETIME = datetime.datetime.fromtimestamp(CREATE_TIME_SECONDS) -CREATE_TIME_JSON_2 = { - 'seconds': CREATE_TIME_SECONDS_2 -} -UPDATE_TIME_SECONDS = 1566426678 -UPDATE_TIME_SECONDS_2 = 1566426691 -UPDATE_TIME_JSON = { - 'seconds': UPDATE_TIME_SECONDS -} -UPDATE_TIME_DATETIME = datetime.datetime.fromtimestamp(UPDATE_TIME_SECONDS) -UPDATE_TIME_JSON_2 = { - 'seconds': UPDATE_TIME_SECONDS_2 -} +CREATE_TIME = '2020-01-21T20:44:27.392932Z' +CREATE_TIME_MILLIS = 1579639467392 + +UPDATE_TIME = '2020-01-21T22:45:29.392932Z' +UPDATE_TIME_MILLIS = 1579646729392 + +CREATE_TIME_2 = '2020-01-21T21:44:27.392932Z' +UPDATE_TIME_2 = '2020-01-21T23:45:29.392932Z' + ETAG = '33a64df551425fcc55e4d42a148795d9f25f89d4' MODEL_HASH = '987987a98b98798d098098e09809fc0893897' TAG_1 = 'Tag1' @@ -130,8 +120,8 @@ CREATED_UPDATED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, - 'createTime': CREATE_TIME_JSON, - 'updateTime': UPDATE_TIME_JSON, + 'createTime': CREATE_TIME, + 'updateTime': UPDATE_TIME, 'state': MODEL_STATE_ERROR_JSON, 'etag': ETAG, 'modelHash': MODEL_HASH, @@ -142,8 +132,8 @@ LOCKED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, - 'createTime': CREATE_TIME_JSON, - 'updateTime': UPDATE_TIME_JSON, + 'createTime': CREATE_TIME, + 'updateTime': UPDATE_TIME, 'tags': TAGS, 'activeOperations': [OPERATION_NOT_DONE_JSON_1] } @@ -151,8 +141,8 @@ LOCKED_MODEL_JSON_2 = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_2, - 'createTime': CREATE_TIME_JSON_2, - 'updateTime': UPDATE_TIME_JSON_2, + 'createTime': CREATE_TIME_2, + 'updateTime': UPDATE_TIME_2, 'tags': TAGS_2, 'activeOperations': [OPERATION_NOT_DONE_JSON_1] } @@ -183,8 +173,8 @@ FULL_MODEL_ERR_STATE_LRO_JSON = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, - 'createTime': CREATE_TIME_JSON, - 'updateTime': UPDATE_TIME_JSON, + 'createTime': CREATE_TIME, + 'updateTime': UPDATE_TIME, 'state': MODEL_STATE_ERROR_JSON, 'etag': ETAG, 'modelHash': MODEL_HASH, @@ -194,8 +184,8 @@ FULL_MODEL_PUBLISHED_JSON = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, - 'createTime': CREATE_TIME_JSON, - 'updateTime': UPDATE_TIME_JSON, + 'createTime': CREATE_TIME, + 'updateTime': UPDATE_TIME, 'state': MODEL_STATE_PUBLISHED_JSON, 'etag': ETAG, 'modelHash': MODEL_HASH, @@ -364,8 +354,8 @@ def test_model_success_err_state_lro(self): model = ml.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 - assert model.create_time == CREATE_TIME_DATETIME - assert model.update_time == UPDATE_TIME_DATETIME + assert model.create_time == CREATE_TIME_MILLIS + assert model.update_time == UPDATE_TIME_MILLIS assert model.validation_error == VALIDATION_ERROR_MSG assert model.published is False assert model.etag == ETAG @@ -379,8 +369,8 @@ def test_model_success_published(self): model = ml.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 - assert model.create_time == CREATE_TIME_DATETIME - assert model.update_time == UPDATE_TIME_DATETIME + assert model.create_time == CREATE_TIME_MILLIS + assert model.update_time == UPDATE_TIME_MILLIS assert model.validation_error is None assert model.published is True assert model.etag == ETAG From cf748c893cb33c1e9cf1b6655f55a9b93544848d Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 27 Jan 2020 17:47:46 -0500 Subject: [PATCH 30/37] Firebase Ml Fix upload file naming (#392) * Fix File Naming --- firebase_admin/_utils.py | 1 + firebase_admin/ml.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 495632ad2..e1eb83d90 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -78,6 +78,7 @@ 16: exceptions.UNAUTHENTICATED, } + def _get_initialized_app(app): if app is None: return firebase_admin.get_app() diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index d34bdd581..fa604933c 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -22,6 +22,7 @@ import datetime import re import time +import os import requests import six @@ -440,8 +441,10 @@ def _parse_gcs_tflite_uri(uri): def upload(bucket_name, model_file_name, app): """Upload a model file to the specified Storage bucket.""" _CloudStorageClient._assert_gcs_enabled() + + file_name = os.path.basename(model_file_name) bucket = storage.bucket(bucket_name, app=app) - blob_name = _CloudStorageClient.BLOB_NAME.format(model_file_name) + blob_name = _CloudStorageClient.BLOB_NAME.format(file_name) blob = bucket.blob(blob_name) blob.upload_from_filename(model_file_name) return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name) From 0b706876df564b9fb49aae592d7d7ec1ae452454 Mon Sep 17 00:00:00 2001 From: ifielker Date: Thu, 30 Jan 2020 14:48:05 -0500 Subject: [PATCH 31/37] Integration tests for Firebase ML (#394) * Integration tests for Firebase ML --- integration/test_ml.py | 373 ++++++++++++++++++++++++++++++++ tests/data/invalid_model.tflite | 1 + tests/data/model1.tflite | Bin 0 -> 736 bytes 3 files changed, 374 insertions(+) create mode 100644 integration/test_ml.py create mode 100644 tests/data/invalid_model.tflite create mode 100644 tests/data/model1.tflite diff --git a/integration/test_ml.py b/integration/test_ml.py new file mode 100644 index 000000000..4c44f3d10 --- /dev/null +++ b/integration/test_ml.py @@ -0,0 +1,373 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for firebase_admin.ml module.""" +import os +import random +import re +import shutil +import string +import tempfile +import pytest + + +from firebase_admin import exceptions +from firebase_admin import ml +from tests import testutils + + +# pylint: disable=import-error,no-name-in-module +try: + import tensorflow as tf + _TF_ENABLED = True +except ImportError: + _TF_ENABLED = False + + +def _random_identifier(prefix): + #pylint: disable=unused-variable + suffix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)]) + return '{0}_{1}'.format(prefix, suffix) + + +NAME_ONLY_ARGS = { + 'display_name': _random_identifier('TestModel123_') +} +NAME_ONLY_ARGS_UPDATED = { + 'display_name': _random_identifier('TestModel123_updated_') +} +NAME_AND_TAGS_ARGS = { + 'display_name': _random_identifier('TestModel123_tags_'), + 'tags': ['test_tag123'] +} +FULL_MODEL_ARGS = { + 'display_name': _random_identifier('TestModel123_full_'), + 'tags': ['test_tag567'], + 'file_name': 'model1.tflite' +} +INVALID_FULL_MODEL_ARGS = { + 'display_name': _random_identifier('TestModel123_invalid_full_'), + 'tags': ['test_tag890'], + 'file_name': 'invalid_model.tflite' +} + + +@pytest.fixture +def firebase_model(request): + args = request.param + tflite_format = None + file_name = args.get('file_name') + if file_name: + file_path = testutils.resource_filename(file_name) + source = ml.TFLiteGCSModelSource.from_tflite_model_file(file_path) + tflite_format = ml.TFLiteFormat(model_source=source) + + ml_model = ml.Model( + display_name=args.get('display_name'), + tags=args.get('tags'), + model_format=tflite_format) + model = ml.create_model(model=ml_model) + yield model + _clean_up_model(model) + + +@pytest.fixture +def model_list(): + ml_model_1 = ml.Model(display_name=_random_identifier('TestModel123_list1_')) + model_1 = ml.create_model(model=ml_model_1) + + ml_model_2 = ml.Model(display_name=_random_identifier('TestModel123_list2_'), + tags=['test_tag123']) + model_2 = ml.create_model(model=ml_model_2) + + yield [model_1, model_2] + + _clean_up_model(model_1) + _clean_up_model(model_2) + + +def _clean_up_model(model): + try: + # Try to delete the model. + # Some tests delete the model as part of the test. + ml.delete_model(model.model_id) + except exceptions.NotFoundError: + pass + + +# For rpc errors +def check_firebase_error(excinfo, status, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert err.cause is not None + assert err.http_response is not None + assert err.http_response.status_code == status + assert str(err) == msg + + +# For operation errors +def check_operation_error(excinfo, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert str(err) == msg + + +def check_model(model, args): + assert model.display_name == args.get('display_name') + assert model.tags == args.get('tags') + assert model.model_id is not None + assert model.create_time is not None + assert model.update_time is not None + assert model.locked is False + assert model.etag is not None + + +def check_model_format(model, has_model_format=False, validation_error=None): + if has_model_format: + assert model.validation_error == validation_error + assert model.published is False + assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://') + if validation_error: + assert model.model_format.size_bytes is None + assert model.model_hash is None + else: + assert model.model_format.size_bytes is not None + assert model.model_hash is not None + else: + assert model.model_format is None + assert model.validation_error == 'No model file has been uploaded.' + assert model.published is False + assert model.model_hash is None + + +@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) +def test_create_simple_model(firebase_model): + check_model(firebase_model, NAME_AND_TAGS_ARGS) + check_model_format(firebase_model) + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_create_full_model(firebase_model): + check_model(firebase_model, FULL_MODEL_ARGS) + check_model_format(firebase_model, True) + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_create_already_existing_fails(firebase_model): + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: + ml.create_model(model=firebase_model) + check_operation_error( + excinfo, + 'Model \'{0}\' already exists'.format(firebase_model.display_name)) + + +@pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True) +def test_create_invalid_model(firebase_model): + check_model(firebase_model, INVALID_FULL_MODEL_ARGS) + check_model_format(firebase_model, True, 'Invalid flatbuffer format') + + +@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) +def test_get_model(firebase_model): + get_model = ml.get_model(firebase_model.model_id) + check_model(get_model, NAME_AND_TAGS_ARGS) + check_model_format(get_model) + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_get_non_existing_model(firebase_model): + # Get a valid model_id that no longer exists + ml.delete_model(firebase_model.model_id) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.get_model(firebase_model.model_id) + check_firebase_error(excinfo, 404, 'Requested entity was not found.') + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_update_model(firebase_model): + new_model_name = NAME_ONLY_ARGS_UPDATED.get('display_name') + firebase_model.display_name = new_model_name + updated_model = ml.update_model(firebase_model) + check_model(updated_model, NAME_ONLY_ARGS_UPDATED) + check_model_format(updated_model) + + # Second call with same model does not cause error + updated_model2 = ml.update_model(updated_model) + check_model(updated_model2, NAME_ONLY_ARGS_UPDATED) + check_model_format(updated_model2) + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_update_non_existing_model(firebase_model): + ml.delete_model(firebase_model.model_id) + + firebase_model.tags = ['tag987'] + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.update_model(firebase_model) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_publish_unpublish_model(firebase_model): + assert firebase_model.published is False + + published_model = ml.publish_model(firebase_model.model_id) + assert published_model.published is True + + unpublished_model = ml.unpublish_model(published_model.model_id) + assert unpublished_model.published is False + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_publish_invalid_fails(firebase_model): + assert firebase_model.validation_error is not None + + with pytest.raises(exceptions.FailedPreconditionError) as excinfo: + ml.publish_model(firebase_model.model_id) + check_operation_error( + excinfo, + 'Cannot publish a model that is not verified.') + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_publish_unpublish_non_existing_model(firebase_model): + ml.delete_model(firebase_model.model_id) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.publish_model(firebase_model.model_id) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.unpublish_model(firebase_model.model_id) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + + +def test_list_models(model_list): + filter_str = 'displayName={0} OR tags:{1}'.format( + model_list[0].display_name, model_list[1].tags[0]) + + all_models = ml.list_models(list_filter=filter_str) + all_model_ids = [mdl.model_id for mdl in all_models.iterate_all()] + for mdl in model_list: + assert mdl.model_id in all_model_ids + + +def test_list_models_invalid_filter(): + invalid_filter = 'InvalidFilterParam=123' + + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: + ml.list_models(list_filter=invalid_filter) + check_firebase_error(excinfo, 400, 'Request contains an invalid argument.') + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_delete_model(firebase_model): + ml.delete_model(firebase_model.model_id) + + # Second delete of same model will fail + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.delete_model(firebase_model.model_id) + check_firebase_error(excinfo, 404, 'Requested entity was not found.') + + +# Test tensor flow conversion functions if tensor flow is enabled. +#'pip install tensorflow' in the environment if you want _TF_ENABLED = True +#'pip install tensorflow==2.0.0b' for version 2 etc. + + +def _clean_up_directory(save_dir): + if save_dir.startswith(tempfile.gettempdir()) and os.path.exists(save_dir): + shutil.rmtree(save_dir) + + +@pytest.fixture +def keras_model(): + assert _TF_ENABLED + x_array = [-1, 0, 1, 2, 3, 4] + y_array = [-3, -1, 1, 3, 5, 7] + model = tf.keras.models.Sequential( + [tf.keras.layers.Dense(units=1, input_shape=[1])]) + model.compile(optimizer='sgd', loss='mean_squared_error') + model.fit(x_array, y_array, epochs=3) + return model + + +@pytest.fixture +def saved_model_dir(keras_model): + assert _TF_ENABLED + # Make a new parent directory. The child directory must not exist yet. + # The child directory gets created by tf. If it exists, the tf call fails. + parent = tempfile.mkdtemp() + save_dir = os.path.join(parent, 'child') + + # different versions have different model conversion capability + # pick something that works for each version + if tf.version.VERSION.startswith('1.'): + tf.reset_default_graph() + x_var = tf.placeholder(tf.float32, (None, 3), name="x") + y_var = tf.multiply(x_var, x_var, name="y") + with tf.Session() as sess: + tf.saved_model.simple_save(sess, save_dir, {"x": x_var}, {"y": y_var}) + else: + # If it's not version 1.x or version 2.x we need to update the test. + assert tf.version.VERSION.startswith('2.') + tf.saved_model.save(keras_model, save_dir) + yield save_dir + _clean_up_directory(parent) + + +@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.') +def test_from_keras_model(keras_model): + source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite') + assert re.search( + '^gs://.*/Firebase/ML/Models/model2.tflite$', + source.gcs_tflite_uri) is not None + + # Validate the conversion by creating a model + model_format = ml.TFLiteFormat(model_source=source) + model = ml.Model(display_name=_random_identifier('KerasModel_'), model_format=model_format) + created_model = ml.create_model(model) + + try: + check_model(created_model, {'display_name': model.display_name}) + check_model_format(created_model, True) + finally: + _clean_up_model(created_model) + + +@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.') +def test_from_saved_model(saved_model_dir): + # Test the conversion helper + source = ml.TFLiteGCSModelSource.from_saved_model(saved_model_dir, 'model3.tflite') + assert re.search( + '^gs://.*/Firebase/ML/Models/model3.tflite$', + source.gcs_tflite_uri) is not None + + # Validate the conversion by creating a model + model_format = ml.TFLiteFormat(model_source=source) + model = ml.Model(display_name=_random_identifier('SavedModel_'), model_format=model_format) + created_model = ml.create_model(model) + + try: + assert created_model.model_id is not None + assert created_model.validation_error is None + finally: + _clean_up_model(created_model) diff --git a/tests/data/invalid_model.tflite b/tests/data/invalid_model.tflite new file mode 100644 index 000000000..d8482f436 --- /dev/null +++ b/tests/data/invalid_model.tflite @@ -0,0 +1 @@ +This is not a tflite file. diff --git a/tests/data/model1.tflite b/tests/data/model1.tflite new file mode 100644 index 0000000000000000000000000000000000000000..c4b71b7a222ebc59ee9fa1239fe2b8efb382cf8b GIT binary patch literal 736 zcmaJY5FPb2r$h~!B1MWTQ%GV^!A6@CK}ZNlustqhi-Y8H-iIg%{+S@QcK(Dk zR)Rl3tSqc7Y;=9^?h;aY@NQ;j_RY-O-KvOmPg{E;TT&H6Oeso9%7|7F5m^Gpi-MRS zFR}v^fd$|pw*l-X(CyeA%O3exDvVXXE-Q$g0Ea*gAfI&%;7x1IS|70Au#+FH4KSD! zSQ8%k6Eu29ZW(^Feo)_K>{n|Oma%PM==n~V_^~%s4thu4$j6N3nHtV(0aQiaEoyRp zezXMpUIWy`Jtl%{m~A>Qxh()kG2`sRkJM$N(Aph1%|>7Ok%DczaXT3_&XwE0a6`}S z4OAy+#G&g)!6;Io$rCiN=X3S*IL!O-tl7r`=KFA-Gt`c~_y(?gfqS2GxQ`s3wFx?^MfSDns>_^X1%E{Y8Sb)GfRIXJ-KXm39D=`^W;!7eZm6%(eLy;H^LSfV^(T? zeSA40uL6|QKPM`rbs3=!d?x$wt<-YMb0LrVras*CGoXmIndd!cZ>NyH9V}M=02G>W Ai~s-t literal 0 HcmV?d00001 From 7295ea458d688d26c3b9053072b0dd24f770dd33 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Thu, 30 Jan 2020 15:54:21 -0800 Subject: [PATCH 32/37] Fixing lint errors for Py3 (#401) * Fixing lint errors for Py3 * Removed dependency on six * Fixing a couple of merge errors --- firebase_admin/_auth_utils.py | 3 -- firebase_admin/_utils.py | 8 +-- firebase_admin/ml.py | 92 +++++++++++++++++------------------ integration/test_auth.py | 1 - integration/test_ml.py | 2 +- tests/test_db.py | 1 - tests/test_ml.py | 18 ++++--- tests/test_user_mgt.py | 6 --- 8 files changed, 59 insertions(+), 72 deletions(-) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 31c6d49ba..2f7383c0b 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -21,9 +21,6 @@ from firebase_admin import exceptions from firebase_admin import _utils -from firebase_admin import exceptions -from firebase_admin import _utils - MAX_CLAIMS_PAYLOAD_SIZE = 1000 RESERVED_CLAIMS = set([ diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 529b74f5b..a5fc8d022 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -187,11 +187,11 @@ def handle_requests_error(error, message=None, code=None): return exceptions.DeadlineExceededError( message='Timed out while making an API call: {0}'.format(error), cause=error) - elif isinstance(error, requests.exceptions.ConnectionError): + if isinstance(error, requests.exceptions.ConnectionError): return exceptions.UnavailableError( message='Failed to establish a connection: {0}'.format(error), cause=error) - elif error.response is None: + if error.response is None: return exceptions.UnknownError( message='Unknown error while making a remote service call: {0}'.format(error), cause=error) @@ -274,11 +274,11 @@ def handle_googleapiclient_error(error, message=None, code=None, http_response=N return exceptions.DeadlineExceededError( message='Timed out while making an API call: {0}'.format(error), cause=error) - elif isinstance(error, httplib2.ServerNotFoundError): + if isinstance(error, httplib2.ServerNotFoundError): return exceptions.UnavailableError( message='Failed to establish a connection: {0}'.format(error), cause=error) - elif not isinstance(error, googleapiclient.errors.HttpError): + if not isinstance(error, googleapiclient.errors.HttpError): return exceptions.UnknownError( message='Unknown error while making a remote service call: {0}'.format(error), cause=error) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index fa604933c..b7d8b818b 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -23,11 +23,10 @@ import re import time import os -import requests -import six +from urllib import parse +import requests -from six.moves import urllib from firebase_admin import _http_client from firebase_admin import _utils from firebase_admin import exceptions @@ -175,7 +174,7 @@ def delete_model(model_id, app=None): ml_service.delete_model(model_id) -class Model(object): +class Model: """A Firebase ML Model object. Args: @@ -218,8 +217,7 @@ def __eq__(self, other): if isinstance(other, self.__class__): # pylint: disable=protected-access return self._data == other._data and self._model_format == other._model_format - else: - return False + return False def __ne__(self, other): return not self.__eq__(other) @@ -341,7 +339,7 @@ def as_dict(self, for_upload=False): return copy -class ModelFormat(object): +class ModelFormat: """Abstract base class representing a Model Format such as TFLite.""" def as_dict(self, for_upload=False): """Returns a serializable representation of the object.""" @@ -378,8 +376,7 @@ def __eq__(self, other): if isinstance(other, self.__class__): # pylint: disable=protected-access return self._data == other._data and self._model_source == other._model_source - else: - return False + return False def __ne__(self, other): return not self.__eq__(other) @@ -409,14 +406,14 @@ def as_dict(self, for_upload=False): return {'tfliteModel': copy} -class TFLiteModelSource(object): +class TFLiteModelSource: """Abstract base class representing a model source for TFLite format models.""" def as_dict(self, for_upload=False): """Returns a serializable representation of the object.""" raise NotImplementedError -class _CloudStorageClient(object): +class _CloudStorageClient: """Cloud Storage helper class""" GCS_URI = 'gs://{0}/{1}' @@ -475,8 +472,7 @@ def __init__(self, gcs_tflite_uri, app=None): def __eq__(self, other): if isinstance(other, self.__class__): return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access - else: - return False + return False def __ne__(self, other): return not self.__eq__(other) @@ -517,15 +513,16 @@ def _tf_convert_from_saved_model(saved_model_dir): @staticmethod def _tf_convert_from_keras_model(keras_model): + """Converts the given Keras model into a TF Lite model.""" # Version 1.x conversion function takes a model file. Version 2.x takes the model itself. if tf.version.VERSION.startswith('1.'): keras_file = 'firebase_keras_model.h5' tf.keras.models.save_model(keras_model, keras_file) converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) - return converter.convert() else: converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) - return converter.convert() + + return converter.convert() @classmethod def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite', @@ -596,7 +593,7 @@ def as_dict(self, for_upload=False): return {'gcsTfliteUri': self._gcs_tflite_uri} -class ListModelsPage(object): +class ListModelsPage: """Represents a page of models in a firebase project. Provides methods for traversing the models included in this page, as well as @@ -662,7 +659,7 @@ def iterate_all(self): return _ModelIterator(self) -class _ModelIterator(object): +class _ModelIterator: """An iterator that allows iterating over models, one at a time. This implementation loads a page of models into memory, and iterates on them. @@ -730,7 +727,7 @@ def _validate_display_name(display_name): def _validate_tags(tags): if not isinstance(tags, list) or not \ - all(isinstance(tag, six.string_types) for tag in tags): + all(isinstance(tag, str) for tag in tags): raise TypeError('Tags must be a list of strings.') if not all(_TAG_PATTERN.match(tag) for tag in tags): raise ValueError('Tag format is invalid.') @@ -753,7 +750,7 @@ def _validate_model_format(model_format): def _validate_list_filter(list_filter): if list_filter is not None: - if not isinstance(list_filter, six.string_types): + if not isinstance(list_filter, str): raise TypeError('List filter must be a string or None.') @@ -769,11 +766,11 @@ def _validate_page_size(page_size): def _validate_page_token(page_token): if page_token is not None: - if not isinstance(page_token, six.string_types): + if not isinstance(page_token, str): raise TypeError('Page token must be a string or None.') -class _MLService(object): +class _MLService: """Firebase ML service.""" PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/' @@ -811,8 +808,7 @@ def _exponential_backoff(self, current_attempt, stop_time): max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds() if max_seconds_left < 1: # allow a bit of time for rpc raise exceptions.DeadlineExceededError('Polling max time exceeded.') - else: - wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) + wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) time.sleep(wait_time_seconds) def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None): @@ -831,6 +827,7 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds Raises: TypeError: if the operation is not a dictionary. ValueError: If the operation is malformed. + UnknownError: If the server responds with an unexpected response. err: If the operation exceeds polling attempts or stop_time """ if not isinstance(operation, dict): @@ -840,31 +837,31 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds # Operations which are immediately done don't have an operation name if operation.get('response'): return operation.get('response') - elif operation.get('error'): + if operation.get('error'): raise _utils.handle_operation_error(operation.get('error')) raise exceptions.UnknownError(message='Internal Error: Malformed Operation.') - else: - op_name = operation.get('name') - _, model_id = _validate_and_parse_operation_name(op_name) - current_attempt = 0 - start_time = datetime.datetime.now() - stop_time = (None if max_time_seconds is None else - start_time + datetime.timedelta(seconds=max_time_seconds)) - while wait_for_operation and not operation.get('done'): - # We just got this operation. Wait before getting another - # so we don't exceed the GetOperation maximum request rate. - self._exponential_backoff(current_attempt, stop_time) - operation = self.get_operation(op_name) - current_attempt += 1 - - if operation.get('done'): - if operation.get('response'): - return operation.get('response') - elif operation.get('error'): - raise _utils.handle_operation_error(operation.get('error')) - - # If the operation is not complete or timed out, return a (locked) model instead - return get_model(model_id).as_dict() + + op_name = operation.get('name') + _, model_id = _validate_and_parse_operation_name(op_name) + current_attempt = 0 + start_time = datetime.datetime.now() + stop_time = (None if max_time_seconds is None else + start_time + datetime.timedelta(seconds=max_time_seconds)) + while wait_for_operation and not operation.get('done'): + # We just got this operation. Wait before getting another + # so we don't exceed the GetOperation maximum request rate. + self._exponential_backoff(current_attempt, stop_time) + operation = self.get_operation(op_name) + current_attempt += 1 + + if operation.get('done'): + if operation.get('response'): + return operation.get('response') + if operation.get('error'): + raise _utils.handle_operation_error(operation.get('error')) + + # If the operation is not complete or timed out, return a (locked) model instead + return get_model(model_id).as_dict() def create_model(self, model): @@ -918,8 +915,7 @@ def list_models(self, list_filter, page_size, page_token): params['page_token'] = page_token path = 'models' if params: - # pylint: disable=too-many-function-args - param_str = urllib.parse.urlencode(sorted(params.items()), True) + param_str = parse.urlencode(sorted(params.items()), True) path = path + '?' + param_str try: return self._client.body('get', url=path) diff --git a/integration/test_auth.py b/integration/test_auth.py index b4c5727b3..5d26dd9f1 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -30,7 +30,6 @@ from firebase_admin import credentials - _verify_token_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken' _verify_password_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyPassword' _password_reset_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/resetPassword' diff --git a/integration/test_ml.py b/integration/test_ml.py index 4c44f3d10..be791d8fa 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -19,8 +19,8 @@ import shutil import string import tempfile -import pytest +import pytest from firebase_admin import exceptions from firebase_admin import ml diff --git a/tests/test_db.py b/tests/test_db.py index ce9ea194c..1743347c5 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -729,7 +729,6 @@ def test_parse_db_url_errors(self, url, emulator_host): @pytest.mark.parametrize('url', [ 'https://test.firebaseio.com', 'https://test.firebaseio.com/' ]) - @pytest.mark.skip(reason='only skip until mlkit branch is synced with master') def test_valid_db_url(self, url): firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : url}) ref = db.reference() diff --git a/tests/test_ml.py b/tests/test_ml.py index e91517dd5..6accba1cb 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -15,6 +15,7 @@ """Test cases for the firebase_admin.ml module.""" import json + import pytest import firebase_admin @@ -22,6 +23,7 @@ from firebase_admin import ml from tests import testutils + BASE_URL = 'https://mlkit.googleapis.com/v1beta1/' PROJECT_ID = 'myProject1' PAGE_TOKEN = 'pageToken' @@ -319,7 +321,7 @@ def instrument_ml_service(status=200, payload=None, operations=False, app=None): session_url, adapter(payload, status, recorder)) return recorder -class _TestStorageClient(object): +class _TestStorageClient: @staticmethod def upload(bucket_name, model_file_name, app): del app # unused variable @@ -332,7 +334,7 @@ def sign_uri(gcs_tflite_uri, app): bucket_name, blob_name = ml._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) return GCS_TFLITE_SIGNED_URI_PATTERN.format(bucket_name, blob_name) -class TestModel(object): +class TestModel: """Tests ml.Model class.""" @classmethod def setup_class(cls): @@ -545,7 +547,7 @@ def test_wait_for_unlocked_timeout(self): assert len(recorder) == 1 -class TestCreateModel(object): +class TestCreateModel: """Tests ml.create_model.""" @classmethod def setup_class(cls): @@ -641,7 +643,7 @@ def test_invalid_op_name(self, op_name): check_error(excinfo, ValueError, 'Operation name format is invalid.') -class TestUpdateModel(object): +class TestUpdateModel: """Tests ml.update_model.""" @classmethod def setup_class(cls): @@ -733,7 +735,7 @@ def test_invalid_op_name(self, op_name): check_error(excinfo, ValueError, 'Operation name format is invalid.') -class TestPublishUnpublish(object): +class TestPublishUnpublish: """Tests ml.publish_model and ml.unpublish_model.""" PUBLISH_UNPUBLISH_WITH_ARGS = [ @@ -823,7 +825,7 @@ def test_rpc_error(self, publish_function): assert len(create_recorder) == 1 -class TestGetModel(object): +class TestGetModel: """Tests ml.get_model.""" @classmethod def setup_class(cls): @@ -876,7 +878,7 @@ def evaluate(): testutils.run_without_project_id(evaluate) -class TestDeleteModel(object): +class TestDeleteModel: """Tests ml.delete_model.""" @classmethod def setup_class(cls): @@ -926,7 +928,7 @@ def evaluate(): testutils.run_without_project_id(evaluate) -class TestListModels(object): +class TestListModels: """Tests ml.list_models.""" @classmethod def setup_class(cls): diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index b8da06bbc..958bbf9c4 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -310,12 +310,6 @@ class TestCreateUser: 'PHONE_NUMBER_EXISTS': auth.PhoneNumberAlreadyExistsError, } - already_exists_errors = { - 'DUPLICATE_EMAIL': auth.EmailAlreadyExistsError, - 'DUPLICATE_LOCAL_ID': auth.UidAlreadyExistsError, - 'PHONE_NUMBER_EXISTS': auth.PhoneNumberAlreadyExistsError, - } - @pytest.mark.parametrize('arg', INVALID_STRINGS[1:] + ['a'*129]) def test_invalid_uid(self, user_mgt_app, arg): with pytest.raises(ValueError): From bcefca8d34135113f2ec72f2424b3bc909d95d4a Mon Sep 17 00:00:00 2001 From: ifielker Date: Fri, 20 Mar 2020 14:26:54 -0400 Subject: [PATCH 33/37] Modifying operation handling to support backend changes (#423) * modifying operation handling to support backend changes --- firebase_admin/ml.py | 22 ++++++++++++---------- tests/test_ml.py | 32 ++++++++++++++++---------------- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index b7d8b818b..d6c14c7ac 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -53,10 +53,9 @@ _GCS_TFLITE_URI_PATTERN = re.compile( r'^gs://(?P[a-z0-9_.-]{3,63})/(?P.+)$') _RESOURCE_NAME_PATTERN = re.compile( - r'^projects/(?P[^/]+)/models/(?P[A-Za-z0-9_-]{1,60})$') + r'^projects/(?P[a-z0-9-]{6,30})/models/(?P[A-Za-z0-9_-]{1,60})$') _OPERATION_NAME_PATTERN = re.compile( - r'^operations/project/(?P[^/]+)/model/(?P[A-Za-z0-9_-]{1,60})' + - r'/operation/[^/]+$') + r'^projects/(?P[a-z0-9-]{6,30})/operations/[^/]+$') def _get_ml_service(app): @@ -712,11 +711,10 @@ def _validate_model_id(model_id): raise ValueError('Model ID format is invalid.') -def _validate_and_parse_operation_name(op_name): - matcher = _OPERATION_NAME_PATTERN.match(op_name) - if not matcher: +def _validate_operation_name(op_name): + if not _OPERATION_NAME_PATTERN.match(op_name): raise ValueError('Operation name format is invalid.') - return matcher.group('project_id'), matcher.group('model_id') + return op_name def _validate_display_name(display_name): @@ -793,7 +791,7 @@ def __init__(self, app): base_url=_MLService.OPERATION_URL) def get_operation(self, op_name): - _validate_and_parse_operation_name(op_name) + _validate_operation_name(op_name) try: return self._operation_client.body('get', url=op_name) except requests.exceptions.RequestException as error: @@ -841,8 +839,12 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds raise _utils.handle_operation_error(operation.get('error')) raise exceptions.UnknownError(message='Internal Error: Malformed Operation.') - op_name = operation.get('name') - _, model_id = _validate_and_parse_operation_name(op_name) + op_name = _validate_operation_name(operation.get('name')) + metadata = operation.get('metadata', {}) + metadata_type = metadata.get('@type', '') + if not metadata_type.endswith('ModelOperationMetadata'): + raise TypeError('Unknown type of operation metadata.') + _, model_id = _validate_and_parse_name(metadata.get('name')) current_attempt = 0 start_time = datetime.datetime.now() stop_time = (None if max_time_seconds is None else diff --git a/tests/test_ml.py b/tests/test_ml.py index 6accba1cb..439aeaac3 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -25,7 +25,7 @@ BASE_URL = 'https://mlkit.googleapis.com/v1beta1/' -PROJECT_ID = 'myProject1' +PROJECT_ID = 'my-project-1' PAGE_TOKEN = 'pageToken' NEXT_PAGE_TOKEN = 'nextPageToken' @@ -85,11 +85,11 @@ } } -OPERATION_NAME_1 = 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1) +OPERATION_NAME_1 = 'projects/{0}/operations/123'.format(PROJECT_ID) OPERATION_NOT_DONE_JSON_1 = { 'name': OPERATION_NAME_1, 'metadata': { - '@type': 'type.googleapis.com/google.firebase.ml.v1beta1.ModelOperationMetadata', + '@type': 'type.googleapis.com/google.firebase.ml.v1beta2.ModelOperationMetadata', 'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1), 'basic_operation_status': 'BASIC_OPERATION_STATUS_UPLOADING' } @@ -265,10 +265,10 @@ INVALID_OP_NAME_ARGS = [ 'abc', '123', - 'projects/operations/project/1234/model/abc/operation/123', - 'operations/project/model/abc/operation/123', - 'operations/project/123/model/$#@/operation/123', - 'operations/project/1234/model/abc/operation/123/extrathing', + 'operations/project/1234/model/abc/operation/123', + 'projects/operations/123', + 'projects/$#@/operations/123', + 'projects/1234/operations/123/extrathing', ] PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ '1 and {0}'.format(ml._MAX_PAGE_SIZE) @@ -348,9 +348,9 @@ def teardown_class(cls): testutils.cleanup_apps() @staticmethod - def _op_url(project_id, model_id): + def _op_url(project_id): return BASE_URL + \ - 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + 'projects/{0}/operations/123'.format(project_id) def test_model_success_err_state_lro(self): model = ml.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) @@ -534,7 +534,7 @@ def test_wait_for_unlocked(self): assert model == FULL_MODEL_PUBLISHED assert len(recorder) == 1 assert recorder[0].method == 'GET' - assert recorder[0].url == TestModel._op_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == TestModel._op_url(PROJECT_ID) def test_wait_for_unlocked_timeout(self): recorder = instrument_ml_service( @@ -564,9 +564,9 @@ def _url(project_id): return BASE_URL + 'projects/{0}/models'.format(project_id) @staticmethod - def _op_url(project_id, model_id): + def _op_url(project_id): return BASE_URL + \ - 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + 'projects/{0}/operations/123'.format(project_id) @staticmethod def _get_url(project_id, model_id): @@ -660,9 +660,9 @@ def _url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) @staticmethod - def _op_url(project_id, model_id): + def _op_url(project_id): return BASE_URL + \ - 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + 'projects/{0}/operations/123'.format(project_id) def test_immediate_done(self): instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) @@ -765,9 +765,9 @@ def _get_url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) @staticmethod - def _op_url(project_id, model_id): + def _op_url(project_id): return BASE_URL + \ - 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + 'projects/{0}/operations/123'.format(project_id) @pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS) def test_immediate_done(self, publish_function, published): From e49add8c391a3cadde8d53d139d6a5a5f9f1f975 Mon Sep 17 00:00:00 2001 From: ifielker Date: Fri, 20 Mar 2020 14:46:02 -0400 Subject: [PATCH 34/37] Firebase ML Changing service endpoint (#421) --- firebase_admin/ml.py | 4 ++-- tests/test_ml.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index d6c14c7ac..06429b5a1 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -771,8 +771,8 @@ def _validate_page_token(page_token): class _MLService: """Firebase ML service.""" - PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/' - OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/' + PROJECT_URL = 'https://firebaseml.googleapis.com/v1beta2/projects/{0}/' + OPERATION_URL = 'https://firebaseml.googleapis.com/v1beta2/' POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5 POLL_BASE_WAIT_TIME_SECONDS = 3 diff --git a/tests/test_ml.py b/tests/test_ml.py index 439aeaac3..8813792e6 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -24,8 +24,9 @@ from tests import testutils -BASE_URL = 'https://mlkit.googleapis.com/v1beta1/' +BASE_URL = 'https://firebaseml.googleapis.com/v1beta2/' PROJECT_ID = 'my-project-1' + PAGE_TOKEN = 'pageToken' NEXT_PAGE_TOKEN = 'nextPageToken' @@ -306,7 +307,7 @@ def instrument_ml_service(status=200, payload=None, operations=False, app=None): app = firebase_admin.get_app() ml_service = ml._get_ml_service(app) recorder = [] - session_url = 'https://mlkit.googleapis.com/v1beta1/' + session_url = 'https://firebaseml.googleapis.com/v1beta2/' if isinstance(status, list): adapter = testutils.MockMultiRequestAdapter From f90a021eacf170950e7c87a21ba3f319ee88ec5c Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 14 Apr 2020 19:20:06 -0400 Subject: [PATCH 35/37] Mlkit add headers (#445) * add Headers --- firebase_admin/ml.py | 6 ++++++ tests/test_ml.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 06429b5a1..db1657839 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -27,6 +27,7 @@ import requests +import firebase_admin from firebase_admin import _http_client from firebase_admin import _utils from firebase_admin import exceptions @@ -783,11 +784,16 @@ def __init__(self, app): 'Project ID is required to access ML service. Either set the ' 'projectId option, or use service account credentials.') self._project_url = _MLService.PROJECT_URL.format(self._project_id) + ml_headers = { + 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + } self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), + headers=ml_headers, base_url=self._project_url) self._operation_client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), + headers=ml_headers, base_url=_MLService.OPERATION_URL) def get_operation(self, op_name): diff --git a/tests/test_ml.py b/tests/test_ml.py index 8813792e6..46f36e50a 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -25,6 +25,8 @@ BASE_URL = 'https://firebaseml.googleapis.com/v1beta2/' +HEADER_CLIENT_KEY = 'X-FIREBASE-CLIENT' +HEADER_CLIENT_VALUE = 'fire-admin-python/3.2.1' PROJECT_ID = 'my-project-1' PAGE_TOKEN = 'pageToken' @@ -536,6 +538,7 @@ def test_wait_for_unlocked(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestModel._op_url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_wait_for_unlocked_timeout(self): recorder = instrument_ml_service( @@ -589,8 +592,10 @@ def test_returns_locked(self): assert len(recorder) == 2 assert recorder[0].method == 'POST' assert recorder[0].url == TestCreateModel._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert recorder[1].method == 'GET' assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_operation_error(self): instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) @@ -681,8 +686,10 @@ def test_returns_locked(self): assert len(recorder) == 2 assert recorder[0].method == 'PATCH' assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert recorder[1].method == 'GET' assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_operation_error(self): instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) @@ -778,6 +785,7 @@ def test_immediate_done(self, publish_function, published): assert len(recorder) == 1 assert recorder[0].method == 'PATCH' assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE body = json.loads(recorder[0].body.decode()) assert body.get('state', {}).get('published', None) is published @@ -793,8 +801,10 @@ def test_returns_locked(self, publish_function): assert len(recorder) == 2 assert recorder[0].method == 'PATCH' assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert recorder[1].method == 'GET' assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_operation_error(self, publish_function): @@ -847,6 +857,7 @@ def test_get_model(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert model == MODEL_1 assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 @@ -870,6 +881,7 @@ def test_get_model_error(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_no_project_id(self): def evaluate(): @@ -900,6 +912,7 @@ def test_delete_model(self): assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) def test_delete_model_validation_errors(self, model_id, exc_type): @@ -920,6 +933,7 @@ def test_delete_model_error(self): assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_no_project_id(self): def evaluate(): @@ -957,6 +971,7 @@ def test_list_models_no_args(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestListModels._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE TestListModels._check_page(models_page, 2) assert models_page.has_next_page assert models_page.next_page_token == NEXT_PAGE_TOKEN @@ -975,6 +990,7 @@ def test_list_models_with_all_args(self): TestListModels._url(PROJECT_ID) + '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' .format(PAGE_TOKEN)) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert isinstance(models_page, ml.ListModelsPage) assert len(models_page.models) == 1 assert models_page.models[0] == MODEL_3 @@ -1020,6 +1036,7 @@ def test_list_models_error(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestListModels._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_no_project_id(self): def evaluate(): From 5521fd41a0226a2c3f6794ec66e03f5b53439c83 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 15 Apr 2020 13:12:34 -0400 Subject: [PATCH 36/37] fixed test (#448) --- tests/test_ml.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_ml.py b/tests/test_ml.py index 46f36e50a..10b0441db 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -26,7 +26,7 @@ BASE_URL = 'https://firebaseml.googleapis.com/v1beta2/' HEADER_CLIENT_KEY = 'X-FIREBASE-CLIENT' -HEADER_CLIENT_VALUE = 'fire-admin-python/3.2.1' +HEADER_CLIENT_VALUE = 'fire-admin-python/{0}'.format(firebase_admin.__version__) PROJECT_ID = 'my-project-1' PAGE_TOKEN = 'pageToken' From 564078dd436446e26d9cd79d338ba41bb361a8ba Mon Sep 17 00:00:00 2001 From: ifielker Date: Fri, 17 Apr 2020 11:54:12 -0400 Subject: [PATCH 37/37] Adding tensorflow and keras so we don't skip tests (#449) * Adding tensorflow and keras so we don't skip tests * Add additional instructions for integration tests for ml --- .github/workflows/release.yml | 2 ++ CONTRIBUTING.md | 9 ++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6d626eef2..64ee304ce 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -54,6 +54,8 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install setuptools wheel + pip install tensorflow + pip install keras - name: Run unit tests run: pytest diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 80a607a8d..f6d09b093 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -183,14 +183,17 @@ Then set up your Firebase/GCP project as follows: Firebase Console. Select the "Sign-in method" tab, and enable the "Email/Password" sign-in method, including the Email link (passwordless sign-in) option. - -3. Enable the IAM API: Go to the +3. Enable the Firebase ML API: Go to the + [Google Developers Console]( + https://console.developers.google.com/apis/api/firebaseml.googleapis.com/overview) + and make sure your project is selected. If the API is not already enabled, click Enable. +4. Enable the IAM API: Go to the [Google Cloud Platform Console](https://console.cloud.google.com) and make sure your Firebase/GCP project is selected. Select "APIs & Services > Dashboard" from the main menu, and click the "ENABLE APIS AND SERVICES" button. Search for and enable the "Identity and Access Management (IAM) API". -4. Grant your service account the 'Firebase Authentication Admin' role. This is +5. Grant your service account the 'Firebase Authentication Admin' role. This is required to ensure that exported user records contain the password hashes of the user accounts: 1. Go to [Google Cloud Platform Console / IAM & admin](https://console.cloud.google.com/iam-admin).