Skip to content

Fixing lint errors for Py3 #401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions firebase_admin/_auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
from firebase_admin import exceptions
from firebase_admin import _utils

from firebase_admin import exceptions
from firebase_admin import _utils


MAX_CLAIMS_PAYLOAD_SIZE = 1000
RESERVED_CLAIMS = set([
Expand Down
8 changes: 4 additions & 4 deletions firebase_admin/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,11 @@ def handle_requests_error(error, message=None, code=None):
return exceptions.DeadlineExceededError(
message='Timed out while making an API call: {0}'.format(error),
cause=error)
elif isinstance(error, requests.exceptions.ConnectionError):
if isinstance(error, requests.exceptions.ConnectionError):
return exceptions.UnavailableError(
message='Failed to establish a connection: {0}'.format(error),
cause=error)
elif error.response is None:
if error.response is None:
return exceptions.UnknownError(
message='Unknown error while making a remote service call: {0}'.format(error),
cause=error)
Expand Down Expand Up @@ -274,11 +274,11 @@ def handle_googleapiclient_error(error, message=None, code=None, http_response=N
return exceptions.DeadlineExceededError(
message='Timed out while making an API call: {0}'.format(error),
cause=error)
elif isinstance(error, httplib2.ServerNotFoundError):
if isinstance(error, httplib2.ServerNotFoundError):
return exceptions.UnavailableError(
message='Failed to establish a connection: {0}'.format(error),
cause=error)
elif not isinstance(error, googleapiclient.errors.HttpError):
if not isinstance(error, googleapiclient.errors.HttpError):
return exceptions.UnknownError(
message='Unknown error while making a remote service call: {0}'.format(error),
cause=error)
Expand Down
92 changes: 44 additions & 48 deletions firebase_admin/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@
import re
import time
import os
import requests
import six
from urllib import parse

import requests

from six.moves import urllib
from firebase_admin import _http_client
from firebase_admin import _utils
from firebase_admin import exceptions
Expand Down Expand Up @@ -175,7 +174,7 @@ def delete_model(model_id, app=None):
ml_service.delete_model(model_id)


class Model(object):
class Model:
"""A Firebase ML Model object.

Args:
Expand Down Expand Up @@ -218,8 +217,7 @@ def __eq__(self, other):
if isinstance(other, self.__class__):
# pylint: disable=protected-access
return self._data == other._data and self._model_format == other._model_format
else:
return False
return False

def __ne__(self, other):
return not self.__eq__(other)
Expand Down Expand Up @@ -341,7 +339,7 @@ def as_dict(self, for_upload=False):
return copy


class ModelFormat(object):
class ModelFormat:
"""Abstract base class representing a Model Format such as TFLite."""
def as_dict(self, for_upload=False):
"""Returns a serializable representation of the object."""
Expand Down Expand Up @@ -378,8 +376,7 @@ def __eq__(self, other):
if isinstance(other, self.__class__):
# pylint: disable=protected-access
return self._data == other._data and self._model_source == other._model_source
else:
return False
return False

def __ne__(self, other):
return not self.__eq__(other)
Expand Down Expand Up @@ -409,14 +406,14 @@ def as_dict(self, for_upload=False):
return {'tfliteModel': copy}


class TFLiteModelSource(object):
class TFLiteModelSource:
"""Abstract base class representing a model source for TFLite format models."""
def as_dict(self, for_upload=False):
"""Returns a serializable representation of the object."""
raise NotImplementedError


class _CloudStorageClient(object):
class _CloudStorageClient:
"""Cloud Storage helper class"""

GCS_URI = 'gs://{0}/{1}'
Expand Down Expand Up @@ -475,8 +472,7 @@ def __init__(self, gcs_tflite_uri, app=None):
def __eq__(self, other):
if isinstance(other, self.__class__):
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access
else:
return False
return False

def __ne__(self, other):
return not self.__eq__(other)
Expand Down Expand Up @@ -517,15 +513,16 @@ def _tf_convert_from_saved_model(saved_model_dir):

@staticmethod
def _tf_convert_from_keras_model(keras_model):
"""Converts the given Keras model into a TF Lite model."""
# Version 1.x conversion function takes a model file. Version 2.x takes the model itself.
if tf.version.VERSION.startswith('1.'):
keras_file = 'firebase_keras_model.h5'
tf.keras.models.save_model(keras_model, keras_file)
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
return converter.convert()
else:
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
return converter.convert()

return converter.convert()

@classmethod
def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite',
Expand Down Expand Up @@ -596,7 +593,7 @@ def as_dict(self, for_upload=False):
return {'gcsTfliteUri': self._gcs_tflite_uri}


class ListModelsPage(object):
class ListModelsPage:
"""Represents a page of models in a firebase project.

Provides methods for traversing the models included in this page, as well as
Expand Down Expand Up @@ -662,7 +659,7 @@ def iterate_all(self):
return _ModelIterator(self)


class _ModelIterator(object):
class _ModelIterator:
"""An iterator that allows iterating over models, one at a time.

This implementation loads a page of models into memory, and iterates on them.
Expand Down Expand Up @@ -730,7 +727,7 @@ def _validate_display_name(display_name):

def _validate_tags(tags):
if not isinstance(tags, list) or not \
all(isinstance(tag, six.string_types) for tag in tags):
all(isinstance(tag, str) for tag in tags):
raise TypeError('Tags must be a list of strings.')
if not all(_TAG_PATTERN.match(tag) for tag in tags):
raise ValueError('Tag format is invalid.')
Expand All @@ -753,7 +750,7 @@ def _validate_model_format(model_format):

def _validate_list_filter(list_filter):
if list_filter is not None:
if not isinstance(list_filter, six.string_types):
if not isinstance(list_filter, str):
raise TypeError('List filter must be a string or None.')


Expand All @@ -769,11 +766,11 @@ def _validate_page_size(page_size):

def _validate_page_token(page_token):
if page_token is not None:
if not isinstance(page_token, six.string_types):
if not isinstance(page_token, str):
raise TypeError('Page token must be a string or None.')


class _MLService(object):
class _MLService:
"""Firebase ML service."""

PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/'
Expand Down Expand Up @@ -811,8 +808,7 @@ def _exponential_backoff(self, current_attempt, stop_time):
max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds()
if max_seconds_left < 1: # allow a bit of time for rpc
raise exceptions.DeadlineExceededError('Polling max time exceeded.')
else:
wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1)
wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1)
time.sleep(wait_time_seconds)

def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None):
Expand All @@ -831,6 +827,7 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
Raises:
TypeError: if the operation is not a dictionary.
ValueError: If the operation is malformed.
UnknownError: If the server responds with an unexpected response.
err: If the operation exceeds polling attempts or stop_time
"""
if not isinstance(operation, dict):
Expand All @@ -840,31 +837,31 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
# Operations which are immediately done don't have an operation name
if operation.get('response'):
return operation.get('response')
elif operation.get('error'):
if operation.get('error'):
raise _utils.handle_operation_error(operation.get('error'))
raise exceptions.UnknownError(message='Internal Error: Malformed Operation.')
else:
op_name = operation.get('name')
_, model_id = _validate_and_parse_operation_name(op_name)
current_attempt = 0
start_time = datetime.datetime.now()
stop_time = (None if max_time_seconds is None else
start_time + datetime.timedelta(seconds=max_time_seconds))
while wait_for_operation and not operation.get('done'):
# We just got this operation. Wait before getting another
# so we don't exceed the GetOperation maximum request rate.
self._exponential_backoff(current_attempt, stop_time)
operation = self.get_operation(op_name)
current_attempt += 1

if operation.get('done'):
if operation.get('response'):
return operation.get('response')
elif operation.get('error'):
raise _utils.handle_operation_error(operation.get('error'))

# If the operation is not complete or timed out, return a (locked) model instead
return get_model(model_id).as_dict()

op_name = operation.get('name')
_, model_id = _validate_and_parse_operation_name(op_name)
current_attempt = 0
start_time = datetime.datetime.now()
stop_time = (None if max_time_seconds is None else
start_time + datetime.timedelta(seconds=max_time_seconds))
while wait_for_operation and not operation.get('done'):
# We just got this operation. Wait before getting another
# so we don't exceed the GetOperation maximum request rate.
self._exponential_backoff(current_attempt, stop_time)
operation = self.get_operation(op_name)
current_attempt += 1

if operation.get('done'):
if operation.get('response'):
return operation.get('response')
if operation.get('error'):
raise _utils.handle_operation_error(operation.get('error'))

# If the operation is not complete or timed out, return a (locked) model instead
return get_model(model_id).as_dict()


def create_model(self, model):
Expand Down Expand Up @@ -918,8 +915,7 @@ def list_models(self, list_filter, page_size, page_token):
params['page_token'] = page_token
path = 'models'
if params:
# pylint: disable=too-many-function-args
param_str = urllib.parse.urlencode(sorted(params.items()), True)
param_str = parse.urlencode(sorted(params.items()), True)
path = path + '?' + param_str
try:
return self._client.body('get', url=path)
Expand Down
1 change: 0 additions & 1 deletion integration/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from firebase_admin import credentials



_verify_token_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken'
_verify_password_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyPassword'
_password_reset_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/resetPassword'
Expand Down
2 changes: 1 addition & 1 deletion integration/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import shutil
import string
import tempfile
import pytest

import pytest

from firebase_admin import exceptions
from firebase_admin import ml
Expand Down
1 change: 0 additions & 1 deletion tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,6 @@ def test_parse_db_url_errors(self, url, emulator_host):
@pytest.mark.parametrize('url', [
'https://test.firebaseio.com', 'https://test.firebaseio.com/'
])
@pytest.mark.skip(reason='only skip until mlkit branch is synced with master')
def test_valid_db_url(self, url):
firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : url})
ref = db.reference()
Expand Down
18 changes: 10 additions & 8 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
"""Test cases for the firebase_admin.ml module."""

import json

import pytest

import firebase_admin
from firebase_admin import exceptions
from firebase_admin import ml
from tests import testutils


BASE_URL = 'https://mlkit.googleapis.com/v1beta1/'
PROJECT_ID = 'myProject1'
PAGE_TOKEN = 'pageToken'
Expand Down Expand Up @@ -319,7 +321,7 @@ def instrument_ml_service(status=200, payload=None, operations=False, app=None):
session_url, adapter(payload, status, recorder))
return recorder

class _TestStorageClient(object):
class _TestStorageClient:
@staticmethod
def upload(bucket_name, model_file_name, app):
del app # unused variable
Expand All @@ -332,7 +334,7 @@ def sign_uri(gcs_tflite_uri, app):
bucket_name, blob_name = ml._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri)
return GCS_TFLITE_SIGNED_URI_PATTERN.format(bucket_name, blob_name)

class TestModel(object):
class TestModel:
"""Tests ml.Model class."""
@classmethod
def setup_class(cls):
Expand Down Expand Up @@ -545,7 +547,7 @@ def test_wait_for_unlocked_timeout(self):
assert len(recorder) == 1


class TestCreateModel(object):
class TestCreateModel:
"""Tests ml.create_model."""
@classmethod
def setup_class(cls):
Expand Down Expand Up @@ -641,7 +643,7 @@ def test_invalid_op_name(self, op_name):
check_error(excinfo, ValueError, 'Operation name format is invalid.')


class TestUpdateModel(object):
class TestUpdateModel:
"""Tests ml.update_model."""
@classmethod
def setup_class(cls):
Expand Down Expand Up @@ -733,7 +735,7 @@ def test_invalid_op_name(self, op_name):
check_error(excinfo, ValueError, 'Operation name format is invalid.')


class TestPublishUnpublish(object):
class TestPublishUnpublish:
"""Tests ml.publish_model and ml.unpublish_model."""

PUBLISH_UNPUBLISH_WITH_ARGS = [
Expand Down Expand Up @@ -823,7 +825,7 @@ def test_rpc_error(self, publish_function):
assert len(create_recorder) == 1


class TestGetModel(object):
class TestGetModel:
"""Tests ml.get_model."""
@classmethod
def setup_class(cls):
Expand Down Expand Up @@ -876,7 +878,7 @@ def evaluate():
testutils.run_without_project_id(evaluate)


class TestDeleteModel(object):
class TestDeleteModel:
"""Tests ml.delete_model."""
@classmethod
def setup_class(cls):
Expand Down Expand Up @@ -926,7 +928,7 @@ def evaluate():
testutils.run_without_project_id(evaluate)


class TestListModels(object):
class TestListModels:
"""Tests ml.list_models."""
@classmethod
def setup_class(cls):
Expand Down
6 changes: 0 additions & 6 deletions tests/test_user_mgt.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,12 +310,6 @@ class TestCreateUser:
'PHONE_NUMBER_EXISTS': auth.PhoneNumberAlreadyExistsError,
}

already_exists_errors = {
'DUPLICATE_EMAIL': auth.EmailAlreadyExistsError,
'DUPLICATE_LOCAL_ID': auth.UidAlreadyExistsError,
'PHONE_NUMBER_EXISTS': auth.PhoneNumberAlreadyExistsError,
}

@pytest.mark.parametrize('arg', INVALID_STRINGS[1:] + ['a'*129])
def test_invalid_uid(self, user_mgt_app, arg):
with pytest.raises(ValueError):
Expand Down