diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 31c6d49ba..2f7383c0b 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -21,9 +21,6 @@ from firebase_admin import exceptions from firebase_admin import _utils -from firebase_admin import exceptions -from firebase_admin import _utils - MAX_CLAIMS_PAYLOAD_SIZE = 1000 RESERVED_CLAIMS = set([ diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 529b74f5b..a5fc8d022 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -187,11 +187,11 @@ def handle_requests_error(error, message=None, code=None): return exceptions.DeadlineExceededError( message='Timed out while making an API call: {0}'.format(error), cause=error) - elif isinstance(error, requests.exceptions.ConnectionError): + if isinstance(error, requests.exceptions.ConnectionError): return exceptions.UnavailableError( message='Failed to establish a connection: {0}'.format(error), cause=error) - elif error.response is None: + if error.response is None: return exceptions.UnknownError( message='Unknown error while making a remote service call: {0}'.format(error), cause=error) @@ -274,11 +274,11 @@ def handle_googleapiclient_error(error, message=None, code=None, http_response=N return exceptions.DeadlineExceededError( message='Timed out while making an API call: {0}'.format(error), cause=error) - elif isinstance(error, httplib2.ServerNotFoundError): + if isinstance(error, httplib2.ServerNotFoundError): return exceptions.UnavailableError( message='Failed to establish a connection: {0}'.format(error), cause=error) - elif not isinstance(error, googleapiclient.errors.HttpError): + if not isinstance(error, googleapiclient.errors.HttpError): return exceptions.UnknownError( message='Unknown error while making a remote service call: {0}'.format(error), cause=error) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index fa604933c..b7d8b818b 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -23,11 +23,10 @@ import re import time import os -import requests -import six +from urllib import parse +import requests -from six.moves import urllib from firebase_admin import _http_client from firebase_admin import _utils from firebase_admin import exceptions @@ -175,7 +174,7 @@ def delete_model(model_id, app=None): ml_service.delete_model(model_id) -class Model(object): +class Model: """A Firebase ML Model object. Args: @@ -218,8 +217,7 @@ def __eq__(self, other): if isinstance(other, self.__class__): # pylint: disable=protected-access return self._data == other._data and self._model_format == other._model_format - else: - return False + return False def __ne__(self, other): return not self.__eq__(other) @@ -341,7 +339,7 @@ def as_dict(self, for_upload=False): return copy -class ModelFormat(object): +class ModelFormat: """Abstract base class representing a Model Format such as TFLite.""" def as_dict(self, for_upload=False): """Returns a serializable representation of the object.""" @@ -378,8 +376,7 @@ def __eq__(self, other): if isinstance(other, self.__class__): # pylint: disable=protected-access return self._data == other._data and self._model_source == other._model_source - else: - return False + return False def __ne__(self, other): return not self.__eq__(other) @@ -409,14 +406,14 @@ def as_dict(self, for_upload=False): return {'tfliteModel': copy} -class TFLiteModelSource(object): +class TFLiteModelSource: """Abstract base class representing a model source for TFLite format models.""" def as_dict(self, for_upload=False): """Returns a serializable representation of the object.""" raise NotImplementedError -class _CloudStorageClient(object): +class _CloudStorageClient: """Cloud Storage helper class""" GCS_URI = 'gs://{0}/{1}' @@ -475,8 +472,7 @@ def __init__(self, gcs_tflite_uri, app=None): def __eq__(self, other): if isinstance(other, self.__class__): return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access - else: - return False + return False def __ne__(self, other): return not self.__eq__(other) @@ -517,15 +513,16 @@ def _tf_convert_from_saved_model(saved_model_dir): @staticmethod def _tf_convert_from_keras_model(keras_model): + """Converts the given Keras model into a TF Lite model.""" # Version 1.x conversion function takes a model file. Version 2.x takes the model itself. if tf.version.VERSION.startswith('1.'): keras_file = 'firebase_keras_model.h5' tf.keras.models.save_model(keras_model, keras_file) converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) - return converter.convert() else: converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) - return converter.convert() + + return converter.convert() @classmethod def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite', @@ -596,7 +593,7 @@ def as_dict(self, for_upload=False): return {'gcsTfliteUri': self._gcs_tflite_uri} -class ListModelsPage(object): +class ListModelsPage: """Represents a page of models in a firebase project. Provides methods for traversing the models included in this page, as well as @@ -662,7 +659,7 @@ def iterate_all(self): return _ModelIterator(self) -class _ModelIterator(object): +class _ModelIterator: """An iterator that allows iterating over models, one at a time. This implementation loads a page of models into memory, and iterates on them. @@ -730,7 +727,7 @@ def _validate_display_name(display_name): def _validate_tags(tags): if not isinstance(tags, list) or not \ - all(isinstance(tag, six.string_types) for tag in tags): + all(isinstance(tag, str) for tag in tags): raise TypeError('Tags must be a list of strings.') if not all(_TAG_PATTERN.match(tag) for tag in tags): raise ValueError('Tag format is invalid.') @@ -753,7 +750,7 @@ def _validate_model_format(model_format): def _validate_list_filter(list_filter): if list_filter is not None: - if not isinstance(list_filter, six.string_types): + if not isinstance(list_filter, str): raise TypeError('List filter must be a string or None.') @@ -769,11 +766,11 @@ def _validate_page_size(page_size): def _validate_page_token(page_token): if page_token is not None: - if not isinstance(page_token, six.string_types): + if not isinstance(page_token, str): raise TypeError('Page token must be a string or None.') -class _MLService(object): +class _MLService: """Firebase ML service.""" PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/' @@ -811,8 +808,7 @@ def _exponential_backoff(self, current_attempt, stop_time): max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds() if max_seconds_left < 1: # allow a bit of time for rpc raise exceptions.DeadlineExceededError('Polling max time exceeded.') - else: - wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) + wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) time.sleep(wait_time_seconds) def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None): @@ -831,6 +827,7 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds Raises: TypeError: if the operation is not a dictionary. ValueError: If the operation is malformed. + UnknownError: If the server responds with an unexpected response. err: If the operation exceeds polling attempts or stop_time """ if not isinstance(operation, dict): @@ -840,31 +837,31 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds # Operations which are immediately done don't have an operation name if operation.get('response'): return operation.get('response') - elif operation.get('error'): + if operation.get('error'): raise _utils.handle_operation_error(operation.get('error')) raise exceptions.UnknownError(message='Internal Error: Malformed Operation.') - else: - op_name = operation.get('name') - _, model_id = _validate_and_parse_operation_name(op_name) - current_attempt = 0 - start_time = datetime.datetime.now() - stop_time = (None if max_time_seconds is None else - start_time + datetime.timedelta(seconds=max_time_seconds)) - while wait_for_operation and not operation.get('done'): - # We just got this operation. Wait before getting another - # so we don't exceed the GetOperation maximum request rate. - self._exponential_backoff(current_attempt, stop_time) - operation = self.get_operation(op_name) - current_attempt += 1 - - if operation.get('done'): - if operation.get('response'): - return operation.get('response') - elif operation.get('error'): - raise _utils.handle_operation_error(operation.get('error')) - - # If the operation is not complete or timed out, return a (locked) model instead - return get_model(model_id).as_dict() + + op_name = operation.get('name') + _, model_id = _validate_and_parse_operation_name(op_name) + current_attempt = 0 + start_time = datetime.datetime.now() + stop_time = (None if max_time_seconds is None else + start_time + datetime.timedelta(seconds=max_time_seconds)) + while wait_for_operation and not operation.get('done'): + # We just got this operation. Wait before getting another + # so we don't exceed the GetOperation maximum request rate. + self._exponential_backoff(current_attempt, stop_time) + operation = self.get_operation(op_name) + current_attempt += 1 + + if operation.get('done'): + if operation.get('response'): + return operation.get('response') + if operation.get('error'): + raise _utils.handle_operation_error(operation.get('error')) + + # If the operation is not complete or timed out, return a (locked) model instead + return get_model(model_id).as_dict() def create_model(self, model): @@ -918,8 +915,7 @@ def list_models(self, list_filter, page_size, page_token): params['page_token'] = page_token path = 'models' if params: - # pylint: disable=too-many-function-args - param_str = urllib.parse.urlencode(sorted(params.items()), True) + param_str = parse.urlencode(sorted(params.items()), True) path = path + '?' + param_str try: return self._client.body('get', url=path) diff --git a/integration/test_auth.py b/integration/test_auth.py index b4c5727b3..5d26dd9f1 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -30,7 +30,6 @@ from firebase_admin import credentials - _verify_token_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken' _verify_password_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyPassword' _password_reset_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/resetPassword' diff --git a/integration/test_ml.py b/integration/test_ml.py index 4c44f3d10..be791d8fa 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -19,8 +19,8 @@ import shutil import string import tempfile -import pytest +import pytest from firebase_admin import exceptions from firebase_admin import ml diff --git a/tests/test_db.py b/tests/test_db.py index ce9ea194c..1743347c5 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -729,7 +729,6 @@ def test_parse_db_url_errors(self, url, emulator_host): @pytest.mark.parametrize('url', [ 'https://test.firebaseio.com', 'https://test.firebaseio.com/' ]) - @pytest.mark.skip(reason='only skip until mlkit branch is synced with master') def test_valid_db_url(self, url): firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : url}) ref = db.reference() diff --git a/tests/test_ml.py b/tests/test_ml.py index e91517dd5..6accba1cb 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -15,6 +15,7 @@ """Test cases for the firebase_admin.ml module.""" import json + import pytest import firebase_admin @@ -22,6 +23,7 @@ from firebase_admin import ml from tests import testutils + BASE_URL = 'https://mlkit.googleapis.com/v1beta1/' PROJECT_ID = 'myProject1' PAGE_TOKEN = 'pageToken' @@ -319,7 +321,7 @@ def instrument_ml_service(status=200, payload=None, operations=False, app=None): session_url, adapter(payload, status, recorder)) return recorder -class _TestStorageClient(object): +class _TestStorageClient: @staticmethod def upload(bucket_name, model_file_name, app): del app # unused variable @@ -332,7 +334,7 @@ def sign_uri(gcs_tflite_uri, app): bucket_name, blob_name = ml._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) return GCS_TFLITE_SIGNED_URI_PATTERN.format(bucket_name, blob_name) -class TestModel(object): +class TestModel: """Tests ml.Model class.""" @classmethod def setup_class(cls): @@ -545,7 +547,7 @@ def test_wait_for_unlocked_timeout(self): assert len(recorder) == 1 -class TestCreateModel(object): +class TestCreateModel: """Tests ml.create_model.""" @classmethod def setup_class(cls): @@ -641,7 +643,7 @@ def test_invalid_op_name(self, op_name): check_error(excinfo, ValueError, 'Operation name format is invalid.') -class TestUpdateModel(object): +class TestUpdateModel: """Tests ml.update_model.""" @classmethod def setup_class(cls): @@ -733,7 +735,7 @@ def test_invalid_op_name(self, op_name): check_error(excinfo, ValueError, 'Operation name format is invalid.') -class TestPublishUnpublish(object): +class TestPublishUnpublish: """Tests ml.publish_model and ml.unpublish_model.""" PUBLISH_UNPUBLISH_WITH_ARGS = [ @@ -823,7 +825,7 @@ def test_rpc_error(self, publish_function): assert len(create_recorder) == 1 -class TestGetModel(object): +class TestGetModel: """Tests ml.get_model.""" @classmethod def setup_class(cls): @@ -876,7 +878,7 @@ def evaluate(): testutils.run_without_project_id(evaluate) -class TestDeleteModel(object): +class TestDeleteModel: """Tests ml.delete_model.""" @classmethod def setup_class(cls): @@ -926,7 +928,7 @@ def evaluate(): testutils.run_without_project_id(evaluate) -class TestListModels(object): +class TestListModels: """Tests ml.list_models.""" @classmethod def setup_class(cls): diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index b8da06bbc..958bbf9c4 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -310,12 +310,6 @@ class TestCreateUser: 'PHONE_NUMBER_EXISTS': auth.PhoneNumberAlreadyExistsError, } - already_exists_errors = { - 'DUPLICATE_EMAIL': auth.EmailAlreadyExistsError, - 'DUPLICATE_LOCAL_ID': auth.UidAlreadyExistsError, - 'PHONE_NUMBER_EXISTS': auth.PhoneNumberAlreadyExistsError, - } - @pytest.mark.parametrize('arg', INVALID_STRINGS[1:] + ['a'*129]) def test_invalid_uid(self, user_mgt_app, arg): with pytest.raises(ValueError):