diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 6d626eef2..64ee304ce 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -54,6 +54,8 @@ jobs:
         python -m pip install --upgrade pip
         pip install -r requirements.txt
         pip install setuptools wheel
+        pip install tensorflow
+        pip install keras
     - name: Run unit tests
       run: pytest
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 80a607a8d..f6d09b093 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -183,14 +183,17 @@ Then set up your Firebase/GCP project as follows:
    Firebase Console. Select the "Sign-in method" tab, and enable the
    "Email/Password" sign-in method, including the Email link (passwordless
    sign-in) option.
-
-3. Enable the IAM API: Go to the
+3. Enable the Firebase ML API: Go to the
+   [Google Developers Console](
+   https://console.developers.google.com/apis/api/firebaseml.googleapis.com/overview)
+   and make sure your project is selected. If the API is not already enabled, click Enable.
+4. Enable the IAM API: Go to the
    [Google Cloud Platform Console](https://console.cloud.google.com) and make
    sure your Firebase/GCP project is selected. Select "APIs & Services >
    Dashboard" from the main menu, and click the "ENABLE APIS AND SERVICES"
    button. Search for and enable the "Identity and Access Management (IAM)
    API".
-4. Grant your service account the 'Firebase Authentication Admin' role. This is
+5. Grant your service account the 'Firebase Authentication Admin' role. This is
    required to ensure that exported user records contain the password hashes of
    the user accounts:
    1. Go to [Google Cloud Platform Console / IAM & admin](https://console.cloud.google.com/iam-admin).
diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py
index 2c4cec868..a5fc8d022 100644
--- a/firebase_admin/_utils.py
+++ b/firebase_admin/_utils.py
@@ -59,6 +59,26 @@
 }
 
 
+# See https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto
+_RPC_CODE_TO_ERROR_CODE = {
+    1: exceptions.CANCELLED,
+    2: exceptions.UNKNOWN,
+    3: exceptions.INVALID_ARGUMENT,
+    4: exceptions.DEADLINE_EXCEEDED,
+    5: exceptions.NOT_FOUND,
+    6: exceptions.ALREADY_EXISTS,
+    7: exceptions.PERMISSION_DENIED,
+    8: exceptions.RESOURCE_EXHAUSTED,
+    9: exceptions.FAILED_PRECONDITION,
+    10: exceptions.ABORTED,
+    11: exceptions.OUT_OF_RANGE,
+    13: exceptions.INTERNAL,
+    14: exceptions.UNAVAILABLE,
+    15: exceptions.DATA_LOSS,
+    16: exceptions.UNAUTHENTICATED,
+}
+
+
 def _get_initialized_app(app):
     """Returns a reference to an initialized App instance."""
     if app is None:
@@ -75,6 +95,7 @@ def _get_initialized_app(app):
             ' firebase_admin.App, but given "{0}".'.format(type(app)))
 
 
+
 def get_app_service(app, name, initializer):
     app = _get_initialized_app(app)
     return app._get_service(name, initializer) # pylint: disable=protected-access
@@ -108,6 +129,27 @@ def handle_platform_error_from_requests(error, handle_func=None):
     return exc if exc else _handle_func_requests(error, message, error_dict)
 
 
+def handle_operation_error(error):
+    """Constructs a ``FirebaseError`` from the given operation error.
+
+    Args:
+        error: An error returned by a long running operation.
+
+    Returns:
+        FirebaseError: A ``FirebaseError`` that can be raised to the user code.
+ """ + if not isinstance(error, dict): + return exceptions.UnknownError( + message='Unknown error while making a remote service call: {0}'.format(error), + cause=error) + + rpc_code = error.get('code') + message = error.get('message') + error_code = _rpc_code_to_error_code(rpc_code) + err_type = _error_code_to_exception_type(error_code) + return err_type(message=message) + + def _handle_func_requests(error, message, error_dict): """Constructs a ``FirebaseError`` from the given GCP error. @@ -264,6 +306,9 @@ def _http_status_to_error_code(status): """Maps an HTTP status to a platform error code.""" return _HTTP_STATUS_TO_ERROR_CODE.get(status, exceptions.UNKNOWN) +def _rpc_code_to_error_code(rpc_code): + """Maps an RPC code to a platform error code.""" + return _RPC_CODE_TO_ERROR_CODE.get(rpc_code, exceptions.UNKNOWN) def _error_code_to_exception_type(code): """Maps a platform error code to an exception type.""" diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py new file mode 100644 index 000000000..db1657839 --- /dev/null +++ b/firebase_admin/ml.py @@ -0,0 +1,938 @@ +# Copyright 2019 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase ML module. + +This module contains functions for creating, updating, getting, listing, +deleting, publishing and unpublishing Firebase ML models. +""" + + +import datetime +import re +import time +import os +from urllib import parse + +import requests + +import firebase_admin +from firebase_admin import _http_client +from firebase_admin import _utils +from firebase_admin import exceptions + +# pylint: disable=import-error,no-name-in-module +try: + from firebase_admin import storage + _GCS_ENABLED = True +except ImportError: + _GCS_ENABLED = False + +# pylint: disable=import-error,no-name-in-module +try: + import tensorflow as tf + _TF_ENABLED = True +except ImportError: + _TF_ENABLED = False + +_ML_ATTRIBUTE = '_ml' +_MAX_PAGE_SIZE = 100 +_MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_GCS_TFLITE_URI_PATTERN = re.compile( + r'^gs://(?P[a-z0-9_.-]{3,63})/(?P.+)$') +_RESOURCE_NAME_PATTERN = re.compile( + r'^projects/(?P[a-z0-9-]{6,30})/models/(?P[A-Za-z0-9_-]{1,60})$') +_OPERATION_NAME_PATTERN = re.compile( + r'^projects/(?P[a-z0-9-]{6,30})/operations/[^/]+$') + + +def _get_ml_service(app): + """ Returns an _MLService instance for an App. + + Args: + app: A Firebase App instance (or None to use the default App). + + Returns: + _MLService: An _MLService for the specified App instance. + + Raises: + ValueError: If the app argument is invalid. + """ + return _utils.get_app_service(app, _ML_ATTRIBUTE, _MLService) + + +def create_model(model, app=None): + """Creates a model in Firebase ML. + + Args: + model: An ml.Model to create. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The model that was created in Firebase ML. 
+ """ + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.create_model(model), app=app) + + +def update_model(model, app=None): + """Updates a model in Firebase ML. + + Args: + model: The ml.Model to update. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The updated model. + """ + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.update_model(model), app=app) + + +def publish_model(model_id, app=None): + """Publishes a model in Firebase ML. + + Args: + model_id: The id of the model to publish. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The published model. + """ + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.set_published(model_id, publish=True), app=app) + + +def unpublish_model(model_id, app=None): + """Unpublishes a model in Firebase ML. + + Args: + model_id: The id of the model to unpublish. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The unpublished model. + """ + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.set_published(model_id, publish=False), app=app) + + +def get_model(model_id, app=None): + """Gets a model from Firebase ML. + + Args: + model_id: The id of the model to get. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The requested model. + """ + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.get_model(model_id), app=app) + + +def list_models(list_filter=None, page_size=None, page_token=None, app=None): + """Lists models from Firebase ML. + + Args: + list_filter: a list filter string such as "tags:'tag_1'". None will return all models. + page_size: A number between 1 and 100 inclusive that specifies the maximum + number of models to return per page. None for default. + page_token: A next page token returned from a previous page of results. None + for first page of results. + app: A Firebase app instance (or None to use the default app). + + Returns: + ListModelsPage: A (filtered) list of models. + """ + ml_service = _get_ml_service(app) + return ListModelsPage( + ml_service.list_models, list_filter, page_size, page_token, app=app) + + +def delete_model(model_id, app=None): + """Deletes a model from Firebase ML. + + Args: + model_id: The id of the model you wish to delete. + app: A Firebase app instance (or None to use the default app). + """ + ml_service = _get_ml_service(app) + ml_service.delete_model(model_id) + + +class Model: + """A Firebase ML Model object. + + Args: + display_name: The display name of your model - used to identify your model in code. + tags: Optional list of strings associated with your model. Can be used in list queries. + model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details. + """ + def __init__(self, display_name=None, tags=None, model_format=None): + self._app = None # Only needed for wait_for_unlo + self._data = {} + self._model_format = None + + if display_name is not None: + self.display_name = display_name + if tags is not None: + self.tags = tags + if model_format is not None: + self.model_format = model_format + + @classmethod + def from_dict(cls, data, app=None): + """Create an instance of the object from a dict.""" + data_copy = dict(data) + tflite_format = None + tflite_format_data = data_copy.pop('tfliteModel', None) + data_copy.pop('@type', None) # Returned by Operations. 
(Not needed) + if tflite_format_data: + tflite_format = TFLiteFormat.from_dict(tflite_format_data) + model = Model(model_format=tflite_format) + model._data = data_copy # pylint: disable=protected-access + model._app = app # pylint: disable=protected-access + return model + + def _update_from_dict(self, data): + copy = Model.from_dict(data) + self.model_format = copy.model_format + self._data = copy._data # pylint: disable=protected-access + + def __eq__(self, other): + if isinstance(other, self.__class__): + # pylint: disable=protected-access + return self._data == other._data and self._model_format == other._model_format + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def model_id(self): + """The model's ID, unique to the project.""" + if not self._data.get('name'): + return None + _, model_id = _validate_and_parse_name(self._data.get('name')) + return model_id + + @property + def display_name(self): + """The model's display name, used to refer to the model in code and in + the Firebase console.""" + return self._data.get('displayName') + + @display_name.setter + def display_name(self, display_name): + self._data['displayName'] = _validate_display_name(display_name) + return self + + @staticmethod + def _convert_to_millis(date_string): + if not date_string: + return None + format_str = '%Y-%m-%dT%H:%M:%S.%fZ' + epoch = datetime.datetime.utcfromtimestamp(0) + datetime_object = datetime.datetime.strptime(date_string, format_str) + millis = int((datetime_object - epoch).total_seconds() * 1000) + return millis + + @property + def create_time(self): + """The time the model was created.""" + return Model._convert_to_millis(self._data.get('createTime', None)) + + @property + def update_time(self): + """The time the model was last updated.""" + return Model._convert_to_millis(self._data.get('updateTime', None)) + + @property + def validation_error(self): + """Validation error message.""" + return self._data.get('state', {}).get('validationError', {}).get('message') + + @property + def published(self): + """True if the model is published and available for clients to + download.""" + return bool(self._data.get('state', {}).get('published')) + + @property + def etag(self): + """The entity tag (ETag) of the model resource.""" + return self._data.get('etag') + + @property + def model_hash(self): + """SHA256 hash of the model binary.""" + return self._data.get('modelHash') + + @property + def tags(self): + """Tag strings, used for filtering query results.""" + return self._data.get('tags') + + @tags.setter + def tags(self, tags): + self._data['tags'] = _validate_tags(tags) + return self + + @property + def locked(self): + """True if the Model object is locked by an active operation.""" + return bool(self._data.get('activeOperations') and + len(self._data.get('activeOperations')) > 0) + + def wait_for_unlocked(self, max_time_seconds=None): + """Waits for the model to be unlocked. (All active operations complete) + + Args: + max_time_seconds: The maximum number of seconds to wait for the model to unlock. + (None for no limit) + + Raises: + exceptions.DeadlineExceeded: If max_time_seconds passed and the model is still locked. 
+ """ + if not self.locked: + return + ml_service = _get_ml_service(self._app) + op_name = self._data.get('activeOperations')[0].get('name') + model_dict = ml_service.handle_operation( + ml_service.get_operation(op_name), + wait_for_operation=True, + max_time_seconds=max_time_seconds) + self._update_from_dict(model_dict) + + @property + def model_format(self): + """The model's ``ModelFormat`` object, which represents the model's + format and storage location.""" + return self._model_format + + @model_format.setter + def model_format(self, model_format): + if model_format is not None: + _validate_model_format(model_format) + self._model_format = model_format #Can be None + return self + + def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" + copy = dict(self._data) + if self._model_format: + copy.update(self._model_format.as_dict(for_upload=for_upload)) + return copy + + +class ModelFormat: + """Abstract base class representing a Model Format such as TFLite.""" + def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" + raise NotImplementedError + + +class TFLiteFormat(ModelFormat): + """Model format representing a TFLite model. + + Args: + model_source: A TFLiteModelSource sub class. Specifies the details of the model source. + """ + def __init__(self, model_source=None): + self._data = {} + self._model_source = None + + if model_source is not None: + self.model_source = model_source + + @classmethod + def from_dict(cls, data): + """Create an instance of the object from a dict.""" + data_copy = dict(data) + model_source = None + gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None) + if gcs_tflite_uri: + model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) + tflite_format = TFLiteFormat(model_source=model_source) + tflite_format._data = data_copy # pylint: disable=protected-access + return tflite_format + + + def __eq__(self, other): + if isinstance(other, self.__class__): + # pylint: disable=protected-access + return self._data == other._data and self._model_source == other._model_source + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def model_source(self): + """The TF Lite model's location.""" + return self._model_source + + @model_source.setter + def model_source(self, model_source): + if model_source is not None: + if not isinstance(model_source, TFLiteModelSource): + raise TypeError('Model source must be a TFLiteModelSource object.') + self._model_source = model_source # Can be None + + @property + def size_bytes(self): + """The size in bytes of the TF Lite model.""" + return self._data.get('sizeBytes') + + def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" + copy = dict(self._data) + if self._model_source: + copy.update(self._model_source.as_dict(for_upload=for_upload)) + return {'tfliteModel': copy} + + +class TFLiteModelSource: + """Abstract base class representing a model source for TFLite format models.""" + def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" + raise NotImplementedError + + +class _CloudStorageClient: + """Cloud Storage helper class""" + + GCS_URI = 'gs://{0}/{1}' + BLOB_NAME = 'Firebase/ML/Models/{0}' + + @staticmethod + def _assert_gcs_enabled(): + if not _GCS_ENABLED: + raise ImportError('Failed to import the Cloud Storage library for Python. 
Make sure ' + 'to install the "google-cloud-storage" module.') + + @staticmethod + def _parse_gcs_tflite_uri(uri): + # GCS Bucket naming rules are complex. The regex is not comprehensive. + # See https://cloud.google.com/storage/docs/naming for full details. + matcher = _GCS_TFLITE_URI_PATTERN.match(uri) + if not matcher: + raise ValueError('GCS TFLite URI format is invalid.') + return matcher.group('bucket_name'), matcher.group('blob_name') + + @staticmethod + def upload(bucket_name, model_file_name, app): + """Upload a model file to the specified Storage bucket.""" + _CloudStorageClient._assert_gcs_enabled() + + file_name = os.path.basename(model_file_name) + bucket = storage.bucket(bucket_name, app=app) + blob_name = _CloudStorageClient.BLOB_NAME.format(file_name) + blob = bucket.blob(blob_name) + blob.upload_from_filename(model_file_name) + return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name) + + @staticmethod + def sign_uri(gcs_tflite_uri, app): + """Makes the gcs_tflite_uri readable for GET for 10 minutes via signed_uri.""" + _CloudStorageClient._assert_gcs_enabled() + bucket_name, blob_name = _CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) + bucket = storage.bucket(bucket_name, app=app) + blob = bucket.blob(blob_name) + return blob.generate_signed_url( + version='v4', + expiration=datetime.timedelta(minutes=10), + method='GET' + ) + + +class TFLiteGCSModelSource(TFLiteModelSource): + """TFLite model source representing a tflite model file stored in GCS.""" + + _STORAGE_CLIENT = _CloudStorageClient() + + def __init__(self, gcs_tflite_uri, app=None): + self._app = app + self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @classmethod + def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): + """Uploads the model file to an existing Google Cloud Storage bucket. + + Args: + model_file_name: The name of the model file. + bucket_name: The name of an existing bucket. None to use the default bucket configured + in the app. + app: A Firebase app instance (or None to use the default app). + + Returns: + TFLiteGCSModelSource: The source created from the model_file + + Raises: + ImportError: If the Cloud Storage Library has not been installed. + """ + gcs_uri = TFLiteGCSModelSource._STORAGE_CLIENT.upload(bucket_name, model_file_name, app) + return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app) + + @staticmethod + def _assert_tf_enabled(): + if not _TF_ENABLED: + raise ImportError('Failed to import the tensorflow library for Python. Make sure ' + 'to install the tensorflow module.') + if not tf.version.VERSION.startswith('1.') and not tf.version.VERSION.startswith('2.'): + raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}' + .format(tf.version.VERSION)) + + @staticmethod + def _tf_convert_from_saved_model(saved_model_dir): + # Same for both v1.x and v2.x + converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) + return converter.convert() + + @staticmethod + def _tf_convert_from_keras_model(keras_model): + """Converts the given Keras model into a TF Lite model.""" + # Version 1.x conversion function takes a model file. Version 2.x takes the model itself. 
+ if tf.version.VERSION.startswith('1.'): + keras_file = 'firebase_keras_model.h5' + tf.keras.models.save_model(keras_model, keras_file) + converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) + else: + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + + return converter.convert() + + @classmethod + def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite', + bucket_name=None, app=None): + """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS. + + Args: + saved_model_dir: The saved model directory. + model_file_name: The name that the tflite model will be saved as in Cloud Storage. + bucket_name: The name of an existing bucket. None to use the default bucket configured + in the app. + app: Optional. A Firebase app instance (or None to use the default app) + + Returns: + TFLiteGCSModelSource: The source created from the saved_model_dir + + Raises: + ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. + """ + TFLiteGCSModelSource._assert_tf_enabled() + tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir) + with open(model_file_name, 'wb') as model_file: + model_file.write(tflite_model) + return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) + + @classmethod + def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite', + bucket_name=None, app=None): + """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS. + + Args: + keras_model: A tf.keras model. + model_file_name: The name that the tflite model will be saved as in Cloud Storage. + bucket_name: The name of an existing bucket. None to use the default bucket configured + in the app. + app: Optional. A Firebase app instance (or None to use the default app) + + Returns: + TFLiteGCSModelSource: The source created from the keras_model + + Raises: + ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. + """ + TFLiteGCSModelSource._assert_tf_enabled() + tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model) + with open(model_file_name, 'wb') as model_file: + model_file.write(tflite_model) + return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) + + @property + def gcs_tflite_uri(self): + """URI of the model file in Cloud Storage.""" + return self._gcs_tflite_uri + + @gcs_tflite_uri.setter + def gcs_tflite_uri(self, gcs_tflite_uri): + self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) + + def _get_signed_gcs_tflite_uri(self): + """Signs the GCS uri, so the model file can be uploaded to Firebase ML and verified.""" + return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app) + + def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" + if for_upload: + return {'gcsTfliteUri': self._get_signed_gcs_tflite_uri()} + + return {'gcsTfliteUri': self._gcs_tflite_uri} + + +class ListModelsPage: + """Represents a page of models in a firebase project. + + Provides methods for traversing the models included in this page, as well as + retrieving subsequent pages of models. The iterator returned by + ``iterate_all()`` can be used to iterate through all the models in the + Firebase project starting from this page. 
+ """ + def __init__(self, list_models_func, list_filter, page_size, page_token, app): + self._list_models_func = list_models_func + self._list_filter = list_filter + self._page_size = page_size + self._page_token = page_token + self._app = app + self._list_response = list_models_func(list_filter, page_size, page_token) + + @property + def models(self): + """A list of Models from this page.""" + return [ + Model.from_dict(model, app=self._app) for model in self._list_response.get('models', []) + ] + + @property + def list_filter(self): + """The filter string used to filter the models.""" + return self._list_filter + + @property + def next_page_token(self): + """Token identifying the next page of results.""" + return self._list_response.get('nextPageToken', '') + + @property + def has_next_page(self): + """True if more pages are available.""" + return bool(self.next_page_token) + + def get_next_page(self): + """Retrieves the next page of models if available. + + Returns: + ListModelsPage: Next page of models, or None if this is the last page. + """ + if self.has_next_page: + return ListModelsPage( + self._list_models_func, + self._list_filter, + self._page_size, + self.next_page_token, + self._app) + return None + + def iterate_all(self): + """Retrieves an iterator for Models. + + Returned iterator will iterate through all the models in the Firebase + project starting from this page. The iterator will never buffer more than + one page of models in memory at a time. + + Returns: + iterator: An iterator of Model instances. + """ + return _ModelIterator(self) + + +class _ModelIterator: + """An iterator that allows iterating over models, one at a time. + + This implementation loads a page of models into memory, and iterates on them. + When the whole page has been traversed, it loads another page. This class + never keeps more than one page of entries in memory. + """ + def __init__(self, current_page): + if not isinstance(current_page, ListModelsPage): + raise TypeError('Current page must be a ListModelsPage') + self._current_page = current_page + self._index = 0 + + def next(self): + if self._index == len(self._current_page.models): + if self._current_page.has_next_page: + self._current_page = self._current_page.get_next_page() + self._index = 0 + if self._index < len(self._current_page.models): + result = self._current_page.models[self._index] + self._index += 1 + return result + raise StopIteration + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + +def _validate_and_parse_name(name): + # The resource name is added automatically from API call responses. + # The only way it could be invalid is if someone tries to + # create a model from a dictionary manually and does it incorrectly. 
+ matcher = _RESOURCE_NAME_PATTERN.match(name) + if not matcher: + raise ValueError('Model resource name format is invalid.') + return matcher.group('project_id'), matcher.group('model_id') + + +def _validate_model(model, update_mask=None): + if not isinstance(model, Model): + raise TypeError('Model must be an ml.Model.') + if update_mask is None and not model.display_name: + raise ValueError('Model must have a display name.') + + +def _validate_model_id(model_id): + if not _MODEL_ID_PATTERN.match(model_id): + raise ValueError('Model ID format is invalid.') + + +def _validate_operation_name(op_name): + if not _OPERATION_NAME_PATTERN.match(op_name): + raise ValueError('Operation name format is invalid.') + return op_name + + +def _validate_display_name(display_name): + if not _DISPLAY_NAME_PATTERN.match(display_name): + raise ValueError('Display name format is invalid.') + return display_name + + +def _validate_tags(tags): + if not isinstance(tags, list) or not \ + all(isinstance(tag, str) for tag in tags): + raise TypeError('Tags must be a list of strings.') + if not all(_TAG_PATTERN.match(tag) for tag in tags): + raise ValueError('Tag format is invalid.') + return tags + + +def _validate_gcs_tflite_uri(uri): + # GCS Bucket naming rules are complex. The regex is not comprehensive. + # See https://cloud.google.com/storage/docs/naming for full details. + if not _GCS_TFLITE_URI_PATTERN.match(uri): + raise ValueError('GCS TFLite URI format is invalid.') + return uri + + +def _validate_model_format(model_format): + if not isinstance(model_format, ModelFormat): + raise TypeError('Model format must be a ModelFormat object.') + return model_format + + +def _validate_list_filter(list_filter): + if list_filter is not None: + if not isinstance(list_filter, str): + raise TypeError('List filter must be a string or None.') + + +def _validate_page_size(page_size): + if page_size is not None: + if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck + # Specifically type() to disallow boolean which is a subtype of int + raise TypeError('Page size must be a number or None.') + if page_size < 1 or page_size > _MAX_PAGE_SIZE: + raise ValueError('Page size must be a positive integer between ' + '1 and {0}'.format(_MAX_PAGE_SIZE)) + + +def _validate_page_token(page_token): + if page_token is not None: + if not isinstance(page_token, str): + raise TypeError('Page token must be a string or None.') + + +class _MLService: + """Firebase ML service.""" + + PROJECT_URL = 'https://firebaseml.googleapis.com/v1beta2/projects/{0}/' + OPERATION_URL = 'https://firebaseml.googleapis.com/v1beta2/' + POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5 + POLL_BASE_WAIT_TIME_SECONDS = 3 + + def __init__(self, app): + self._project_id = app.project_id + if not self._project_id: + raise ValueError( + 'Project ID is required to access ML service. 
Either set the ' + 'projectId option, or use service account credentials.') + self._project_url = _MLService.PROJECT_URL.format(self._project_id) + ml_headers = { + 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + } + self._client = _http_client.JsonHttpClient( + credential=app.credential.get_credential(), + headers=ml_headers, + base_url=self._project_url) + self._operation_client = _http_client.JsonHttpClient( + credential=app.credential.get_credential(), + headers=ml_headers, + base_url=_MLService.OPERATION_URL) + + def get_operation(self, op_name): + _validate_operation_name(op_name) + try: + return self._operation_client.body('get', url=op_name) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + + def _exponential_backoff(self, current_attempt, stop_time): + """Sleeps for the appropriate amount of time. Or throws deadline exceeded.""" + delay_factor = pow(_MLService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) + wait_time_seconds = delay_factor * _MLService.POLL_BASE_WAIT_TIME_SECONDS + + if stop_time is not None: + max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds() + if max_seconds_left < 1: # allow a bit of time for rpc + raise exceptions.DeadlineExceededError('Polling max time exceeded.') + wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) + time.sleep(wait_time_seconds) + + def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None): + """Handles long running operations. + + Args: + operation: The operation to handle. + wait_for_operation: Should we allow polling for the operation to complete. + If no polling is requested, a locked model will be returned instead. + max_time_seconds: The maximum seconds to try polling for operation complete. + (None for no limit) + + Returns: + dict: A dictionary of the returned model properties. + + Raises: + TypeError: if the operation is not a dictionary. + ValueError: If the operation is malformed. + UnknownError: If the server responds with an unexpected response. + err: If the operation exceeds polling attempts or stop_time + """ + if not isinstance(operation, dict): + raise TypeError('Operation must be a dictionary.') + + if operation.get('done'): + # Operations which are immediately done don't have an operation name + if operation.get('response'): + return operation.get('response') + if operation.get('error'): + raise _utils.handle_operation_error(operation.get('error')) + raise exceptions.UnknownError(message='Internal Error: Malformed Operation.') + + op_name = _validate_operation_name(operation.get('name')) + metadata = operation.get('metadata', {}) + metadata_type = metadata.get('@type', '') + if not metadata_type.endswith('ModelOperationMetadata'): + raise TypeError('Unknown type of operation metadata.') + _, model_id = _validate_and_parse_name(metadata.get('name')) + current_attempt = 0 + start_time = datetime.datetime.now() + stop_time = (None if max_time_seconds is None else + start_time + datetime.timedelta(seconds=max_time_seconds)) + while wait_for_operation and not operation.get('done'): + # We just got this operation. Wait before getting another + # so we don't exceed the GetOperation maximum request rate. 
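# Concretely, with POLL_BASE_WAIT_TIME_SECONDS = 3 and
# POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5, the sleep before attempts 0, 1, 2, 3 is
# 3 * 1.5**attempt seconds, i.e. roughly 3, 4.5, 6.75 and 10.1 seconds, and each
# sleep is shortened so the loop never waits past the stop_time derived from
# max_time_seconds.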
+ self._exponential_backoff(current_attempt, stop_time) + operation = self.get_operation(op_name) + current_attempt += 1 + + if operation.get('done'): + if operation.get('response'): + return operation.get('response') + if operation.get('error'): + raise _utils.handle_operation_error(operation.get('error')) + + # If the operation is not complete or timed out, return a (locked) model instead + return get_model(model_id).as_dict() + + + def create_model(self, model): + _validate_model(model) + try: + return self.handle_operation( + self._client.body('post', url='models', json=model.as_dict(for_upload=True))) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + + def update_model(self, model, update_mask=None): + _validate_model(model, update_mask) + path = 'models/{0}'.format(model.model_id) + if update_mask is not None: + path = path + '?updateMask={0}'.format(update_mask) + try: + return self.handle_operation( + self._client.body('patch', url=path, json=model.as_dict(for_upload=True))) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + + def set_published(self, model_id, publish): + _validate_model_id(model_id) + model_name = 'projects/{0}/models/{1}'.format(self._project_id, model_id) + model = Model.from_dict({ + 'name': model_name, + 'state': { + 'published': publish + } + }) + return self.update_model(model, update_mask='state.published') + + def get_model(self, model_id): + _validate_model_id(model_id) + try: + return self._client.body('get', url='models/{0}'.format(model_id)) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + + def list_models(self, list_filter, page_size, page_token): + """ lists Firebase ML models.""" + _validate_list_filter(list_filter) + _validate_page_size(page_size) + _validate_page_token(page_token) + params = {} + if list_filter: + params['filter'] = list_filter + if page_size: + params['page_size'] = page_size + if page_token: + params['page_token'] = page_token + path = 'models' + if params: + param_str = parse.urlencode(sorted(params.items()), True) + path = path + '?' + param_str + try: + return self._client.body('get', url=path) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + + def delete_model(self, model_id): + _validate_model_id(model_id) + try: + self._client.body('delete', url='models/{0}'.format(model_id)) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) diff --git a/integration/test_ml.py b/integration/test_ml.py new file mode 100644 index 000000000..be791d8fa --- /dev/null +++ b/integration/test_ml.py @@ -0,0 +1,373 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
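# A minimal sketch of how the operation error handling above surfaces to callers;
# the payload shape follows google.rpc.Status and the values are illustrative only:
from firebase_admin import _utils, exceptions

op_error = {'code': 3, 'message': 'Invalid argument'}
exc = _utils.handle_operation_error(op_error)
assert isinstance(exc, exceptions.InvalidArgumentError)  # RPC code 3 -> INVALID_ARGUMENT
assert isinstance(_utils.handle_operation_error('garbled'), exceptions.UnknownError)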
+ +"""Integration tests for firebase_admin.ml module.""" +import os +import random +import re +import shutil +import string +import tempfile + +import pytest + +from firebase_admin import exceptions +from firebase_admin import ml +from tests import testutils + + +# pylint: disable=import-error,no-name-in-module +try: + import tensorflow as tf + _TF_ENABLED = True +except ImportError: + _TF_ENABLED = False + + +def _random_identifier(prefix): + #pylint: disable=unused-variable + suffix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)]) + return '{0}_{1}'.format(prefix, suffix) + + +NAME_ONLY_ARGS = { + 'display_name': _random_identifier('TestModel123_') +} +NAME_ONLY_ARGS_UPDATED = { + 'display_name': _random_identifier('TestModel123_updated_') +} +NAME_AND_TAGS_ARGS = { + 'display_name': _random_identifier('TestModel123_tags_'), + 'tags': ['test_tag123'] +} +FULL_MODEL_ARGS = { + 'display_name': _random_identifier('TestModel123_full_'), + 'tags': ['test_tag567'], + 'file_name': 'model1.tflite' +} +INVALID_FULL_MODEL_ARGS = { + 'display_name': _random_identifier('TestModel123_invalid_full_'), + 'tags': ['test_tag890'], + 'file_name': 'invalid_model.tflite' +} + + +@pytest.fixture +def firebase_model(request): + args = request.param + tflite_format = None + file_name = args.get('file_name') + if file_name: + file_path = testutils.resource_filename(file_name) + source = ml.TFLiteGCSModelSource.from_tflite_model_file(file_path) + tflite_format = ml.TFLiteFormat(model_source=source) + + ml_model = ml.Model( + display_name=args.get('display_name'), + tags=args.get('tags'), + model_format=tflite_format) + model = ml.create_model(model=ml_model) + yield model + _clean_up_model(model) + + +@pytest.fixture +def model_list(): + ml_model_1 = ml.Model(display_name=_random_identifier('TestModel123_list1_')) + model_1 = ml.create_model(model=ml_model_1) + + ml_model_2 = ml.Model(display_name=_random_identifier('TestModel123_list2_'), + tags=['test_tag123']) + model_2 = ml.create_model(model=ml_model_2) + + yield [model_1, model_2] + + _clean_up_model(model_1) + _clean_up_model(model_2) + + +def _clean_up_model(model): + try: + # Try to delete the model. + # Some tests delete the model as part of the test. 
+ ml.delete_model(model.model_id) + except exceptions.NotFoundError: + pass + + +# For rpc errors +def check_firebase_error(excinfo, status, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert err.cause is not None + assert err.http_response is not None + assert err.http_response.status_code == status + assert str(err) == msg + + +# For operation errors +def check_operation_error(excinfo, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert str(err) == msg + + +def check_model(model, args): + assert model.display_name == args.get('display_name') + assert model.tags == args.get('tags') + assert model.model_id is not None + assert model.create_time is not None + assert model.update_time is not None + assert model.locked is False + assert model.etag is not None + + +def check_model_format(model, has_model_format=False, validation_error=None): + if has_model_format: + assert model.validation_error == validation_error + assert model.published is False + assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://') + if validation_error: + assert model.model_format.size_bytes is None + assert model.model_hash is None + else: + assert model.model_format.size_bytes is not None + assert model.model_hash is not None + else: + assert model.model_format is None + assert model.validation_error == 'No model file has been uploaded.' + assert model.published is False + assert model.model_hash is None + + +@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) +def test_create_simple_model(firebase_model): + check_model(firebase_model, NAME_AND_TAGS_ARGS) + check_model_format(firebase_model) + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_create_full_model(firebase_model): + check_model(firebase_model, FULL_MODEL_ARGS) + check_model_format(firebase_model, True) + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_create_already_existing_fails(firebase_model): + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: + ml.create_model(model=firebase_model) + check_operation_error( + excinfo, + 'Model \'{0}\' already exists'.format(firebase_model.display_name)) + + +@pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True) +def test_create_invalid_model(firebase_model): + check_model(firebase_model, INVALID_FULL_MODEL_ARGS) + check_model_format(firebase_model, True, 'Invalid flatbuffer format') + + +@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) +def test_get_model(firebase_model): + get_model = ml.get_model(firebase_model.model_id) + check_model(get_model, NAME_AND_TAGS_ARGS) + check_model_format(get_model) + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_get_non_existing_model(firebase_model): + # Get a valid model_id that no longer exists + ml.delete_model(firebase_model.model_id) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.get_model(firebase_model.model_id) + check_firebase_error(excinfo, 404, 'Requested entity was not found.') + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_update_model(firebase_model): + new_model_name = NAME_ONLY_ARGS_UPDATED.get('display_name') + firebase_model.display_name = new_model_name + updated_model = ml.update_model(firebase_model) + check_model(updated_model, NAME_ONLY_ARGS_UPDATED) + check_model_format(updated_model) 
+ + # Second call with same model does not cause error + updated_model2 = ml.update_model(updated_model) + check_model(updated_model2, NAME_ONLY_ARGS_UPDATED) + check_model_format(updated_model2) + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_update_non_existing_model(firebase_model): + ml.delete_model(firebase_model.model_id) + + firebase_model.tags = ['tag987'] + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.update_model(firebase_model) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_publish_unpublish_model(firebase_model): + assert firebase_model.published is False + + published_model = ml.publish_model(firebase_model.model_id) + assert published_model.published is True + + unpublished_model = ml.unpublish_model(published_model.model_id) + assert unpublished_model.published is False + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_publish_invalid_fails(firebase_model): + assert firebase_model.validation_error is not None + + with pytest.raises(exceptions.FailedPreconditionError) as excinfo: + ml.publish_model(firebase_model.model_id) + check_operation_error( + excinfo, + 'Cannot publish a model that is not verified.') + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_publish_unpublish_non_existing_model(firebase_model): + ml.delete_model(firebase_model.model_id) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.publish_model(firebase_model.model_id) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.unpublish_model(firebase_model.model_id) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + + +def test_list_models(model_list): + filter_str = 'displayName={0} OR tags:{1}'.format( + model_list[0].display_name, model_list[1].tags[0]) + + all_models = ml.list_models(list_filter=filter_str) + all_model_ids = [mdl.model_id for mdl in all_models.iterate_all()] + for mdl in model_list: + assert mdl.model_id in all_model_ids + + +def test_list_models_invalid_filter(): + invalid_filter = 'InvalidFilterParam=123' + + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: + ml.list_models(list_filter=invalid_filter) + check_firebase_error(excinfo, 400, 'Request contains an invalid argument.') + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_delete_model(firebase_model): + ml.delete_model(firebase_model.model_id) + + # Second delete of same model will fail + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.delete_model(firebase_model.model_id) + check_firebase_error(excinfo, 404, 'Requested entity was not found.') + + +# Test tensor flow conversion functions if tensor flow is enabled. +#'pip install tensorflow' in the environment if you want _TF_ENABLED = True +#'pip install tensorflow==2.0.0b' for version 2 etc. 
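# A hedged end-to-end sketch of the conversion helpers exercised by the fixtures
# below. It assumes TensorFlow is installed and a default app initialized with a
# 'storageBucket'; the bucket, file name and display name are illustrative only.
import firebase_admin
from firebase_admin import credentials, ml
import tensorflow as tf

firebase_admin.initialize_app(
    credentials.ApplicationDefault(), {'storageBucket': 'my-bucket.appspot.com'})
keras_model = tf.keras.models.Sequential(
    [tf.keras.layers.Dense(units=1, input_shape=[1])])
keras_model.compile(optimizer='sgd', loss='mean_squared_error')

source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'example_model.tflite')
model = ml.create_model(ml.Model(
    display_name='example_model',
    model_format=ml.TFLiteFormat(model_source=source)))
model.wait_for_unlocked()          # block until the upload/validation operation finishes
published = ml.publish_model(model.model_id)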
+ + +def _clean_up_directory(save_dir): + if save_dir.startswith(tempfile.gettempdir()) and os.path.exists(save_dir): + shutil.rmtree(save_dir) + + +@pytest.fixture +def keras_model(): + assert _TF_ENABLED + x_array = [-1, 0, 1, 2, 3, 4] + y_array = [-3, -1, 1, 3, 5, 7] + model = tf.keras.models.Sequential( + [tf.keras.layers.Dense(units=1, input_shape=[1])]) + model.compile(optimizer='sgd', loss='mean_squared_error') + model.fit(x_array, y_array, epochs=3) + return model + + +@pytest.fixture +def saved_model_dir(keras_model): + assert _TF_ENABLED + # Make a new parent directory. The child directory must not exist yet. + # The child directory gets created by tf. If it exists, the tf call fails. + parent = tempfile.mkdtemp() + save_dir = os.path.join(parent, 'child') + + # different versions have different model conversion capability + # pick something that works for each version + if tf.version.VERSION.startswith('1.'): + tf.reset_default_graph() + x_var = tf.placeholder(tf.float32, (None, 3), name="x") + y_var = tf.multiply(x_var, x_var, name="y") + with tf.Session() as sess: + tf.saved_model.simple_save(sess, save_dir, {"x": x_var}, {"y": y_var}) + else: + # If it's not version 1.x or version 2.x we need to update the test. + assert tf.version.VERSION.startswith('2.') + tf.saved_model.save(keras_model, save_dir) + yield save_dir + _clean_up_directory(parent) + + +@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.') +def test_from_keras_model(keras_model): + source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite') + assert re.search( + '^gs://.*/Firebase/ML/Models/model2.tflite$', + source.gcs_tflite_uri) is not None + + # Validate the conversion by creating a model + model_format = ml.TFLiteFormat(model_source=source) + model = ml.Model(display_name=_random_identifier('KerasModel_'), model_format=model_format) + created_model = ml.create_model(model) + + try: + check_model(created_model, {'display_name': model.display_name}) + check_model_format(created_model, True) + finally: + _clean_up_model(created_model) + + +@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.') +def test_from_saved_model(saved_model_dir): + # Test the conversion helper + source = ml.TFLiteGCSModelSource.from_saved_model(saved_model_dir, 'model3.tflite') + assert re.search( + '^gs://.*/Firebase/ML/Models/model3.tflite$', + source.gcs_tflite_uri) is not None + + # Validate the conversion by creating a model + model_format = ml.TFLiteFormat(model_source=source) + model = ml.Model(display_name=_random_identifier('SavedModel_'), model_format=model_format) + created_model = ml.create_model(model) + + try: + assert created_model.model_id is not None + assert created_model.validation_error is None + finally: + _clean_up_model(created_model) diff --git a/tests/data/invalid_model.tflite b/tests/data/invalid_model.tflite new file mode 100644 index 000000000..d8482f436 --- /dev/null +++ b/tests/data/invalid_model.tflite @@ -0,0 +1 @@ +This is not a tflite file. diff --git a/tests/data/model1.tflite b/tests/data/model1.tflite new file mode 100644 index 000000000..c4b71b7a2 Binary files /dev/null and b/tests/data/model1.tflite differ diff --git a/tests/test_ml.py b/tests/test_ml.py new file mode 100644 index 000000000..10b0441db --- /dev/null +++ b/tests/test_ml.py @@ -0,0 +1,1113 @@ +# Copyright 2019 Google Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin.ml module.""" + +import json + +import pytest + +import firebase_admin +from firebase_admin import exceptions +from firebase_admin import ml +from tests import testutils + + +BASE_URL = 'https://firebaseml.googleapis.com/v1beta2/' +HEADER_CLIENT_KEY = 'X-FIREBASE-CLIENT' +HEADER_CLIENT_VALUE = 'fire-admin-python/{0}'.format(firebase_admin.__version__) +PROJECT_ID = 'my-project-1' + +PAGE_TOKEN = 'pageToken' +NEXT_PAGE_TOKEN = 'nextPageToken' + +CREATE_TIME = '2020-01-21T20:44:27.392932Z' +CREATE_TIME_MILLIS = 1579639467392 + +UPDATE_TIME = '2020-01-21T22:45:29.392932Z' +UPDATE_TIME_MILLIS = 1579646729392 + +CREATE_TIME_2 = '2020-01-21T21:44:27.392932Z' +UPDATE_TIME_2 = '2020-01-21T23:45:29.392932Z' + +ETAG = '33a64df551425fcc55e4d42a148795d9f25f89d4' +MODEL_HASH = '987987a98b98798d098098e09809fc0893897' +TAG_1 = 'Tag1' +TAG_2 = 'Tag2' +TAG_3 = 'Tag3' +TAGS = [TAG_1, TAG_2] +TAGS_2 = [TAG_1, TAG_3] + +MODEL_ID_1 = 'modelId1' +MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1) +DISPLAY_NAME_1 = 'displayName1' +MODEL_JSON_1 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1 +} +MODEL_1 = ml.Model.from_dict(MODEL_JSON_1) + +MODEL_ID_2 = 'modelId2' +MODEL_NAME_2 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_2) +DISPLAY_NAME_2 = 'displayName2' +MODEL_JSON_2 = { + 'name': MODEL_NAME_2, + 'displayName': DISPLAY_NAME_2 +} +MODEL_2 = ml.Model.from_dict(MODEL_JSON_2) + +MODEL_ID_3 = 'modelId3' +MODEL_NAME_3 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_3) +DISPLAY_NAME_3 = 'displayName3' +MODEL_JSON_3 = { + 'name': MODEL_NAME_3, + 'displayName': DISPLAY_NAME_3 +} +MODEL_3 = ml.Model.from_dict(MODEL_JSON_3) + +MODEL_STATE_PUBLISHED_JSON = { + 'published': True +} +VALIDATION_ERROR_CODE = 400 +VALIDATION_ERROR_MSG = 'No model format found for {0}.'.format(MODEL_ID_1) +MODEL_STATE_ERROR_JSON = { + 'validationError': { + 'code': VALIDATION_ERROR_CODE, + 'message': VALIDATION_ERROR_MSG, + } +} + +OPERATION_NAME_1 = 'projects/{0}/operations/123'.format(PROJECT_ID) +OPERATION_NOT_DONE_JSON_1 = { + 'name': OPERATION_NAME_1, + 'metadata': { + '@type': 'type.googleapis.com/google.firebase.ml.v1beta2.ModelOperationMetadata', + 'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1), + 'basic_operation_status': 'BASIC_OPERATION_STATUS_UPLOADING' + } +} + +GCS_BUCKET_NAME = 'my_bucket' +GCS_BLOB_NAME = 'mymodel.tflite' +GCS_TFLITE_URI = 'gs://{0}/{1}'.format(GCS_BUCKET_NAME, GCS_BLOB_NAME) +GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI} +GCS_TFLITE_MODEL_SOURCE = ml.TFLiteGCSModelSource(GCS_TFLITE_URI) +TFLITE_FORMAT_JSON = { + 'gcsTfliteUri': GCS_TFLITE_URI, + 'sizeBytes': '1234567' +} +TFLITE_FORMAT = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON) + +GCS_TFLITE_SIGNED_URI_PATTERN = ( + 'https://storage.googleapis.com/{0}/{1}?X-Goog-Algorithm=GOOG4-RSA-SHA256&foo') +GCS_TFLITE_SIGNED_URI = GCS_TFLITE_SIGNED_URI_PATTERN.format(GCS_BUCKET_NAME, GCS_BLOB_NAME) + 
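# For reference, a hedged sketch of the behavior the signed-URI constants above stand
# in for (URIs are the illustrative test values, not real resources):
#
#   source = ml.TFLiteGCSModelSource(GCS_TFLITE_URI)
#   source.as_dict()                 # {'gcsTfliteUri': 'gs://my_bucket/mymodel.tflite'}
#   source.as_dict(for_upload=True)  # {'gcsTfliteUri': '<V4 signed https URL, 10-minute GET>'}
#                                    # so the Firebase ML backend can read and verify the file.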
+GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite' +GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2} +GCS_TFLITE_MODEL_SOURCE_2 = ml.TFLiteGCSModelSource(GCS_TFLITE_URI_2) +TFLITE_FORMAT_JSON_2 = { + 'gcsTfliteUri': GCS_TFLITE_URI_2, + 'sizeBytes': '2345678' +} +TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) + +CREATED_UPDATED_MODEL_JSON_1 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME, + 'updateTime': UPDATE_TIME, + 'state': MODEL_STATE_ERROR_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, +} +CREATED_UPDATED_MODEL_1 = ml.Model.from_dict(CREATED_UPDATED_MODEL_JSON_1) + +LOCKED_MODEL_JSON_1 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME, + 'updateTime': UPDATE_TIME, + 'tags': TAGS, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1] +} + +LOCKED_MODEL_JSON_2 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_2, + 'createTime': CREATE_TIME_2, + 'updateTime': UPDATE_TIME_2, + 'tags': TAGS_2, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1] +} + +OPERATION_DONE_MODEL_JSON_1 = { + 'done': True, + 'response': CREATED_UPDATED_MODEL_JSON_1 +} +OPERATION_MALFORMED_JSON_1 = { + 'done': True, + # if done is true then either response or error should be populated +} +OPERATION_MISSING_NAME = { + # Name is required if the operation is not done. + 'done': False +} +OPERATION_ERROR_CODE = 3 +OPERATION_ERROR_MSG = "Invalid argument" +OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT' +OPERATION_ERROR_JSON_1 = { + 'done': True, + 'error': { + 'code': OPERATION_ERROR_CODE, + 'message': OPERATION_ERROR_MSG, + } +} + +FULL_MODEL_ERR_STATE_LRO_JSON = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME, + 'updateTime': UPDATE_TIME, + 'state': MODEL_STATE_ERROR_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1], +} +FULL_MODEL_PUBLISHED_JSON = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME, + 'updateTime': UPDATE_TIME, + 'state': MODEL_STATE_PUBLISHED_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, + 'tfliteModel': TFLITE_FORMAT_JSON +} +FULL_MODEL_PUBLISHED = ml.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) +OPERATION_DONE_FULL_MODEL_PUBLISHED_JSON = { + 'name': OPERATION_NAME_1, + 'done': True, + 'response': FULL_MODEL_PUBLISHED_JSON +} + +EMPTY_RESPONSE = json.dumps({}) +OPERATION_NOT_DONE_RESPONSE = json.dumps(OPERATION_NOT_DONE_JSON_1) +OPERATION_DONE_RESPONSE = json.dumps(OPERATION_DONE_MODEL_JSON_1) +OPERATION_DONE_PUBLISHED_RESPONSE = json.dumps(OPERATION_DONE_FULL_MODEL_PUBLISHED_JSON) +OPERATION_ERROR_RESPONSE = json.dumps(OPERATION_ERROR_JSON_1) +OPERATION_MALFORMED_RESPONSE = json.dumps(OPERATION_MALFORMED_JSON_1) +OPERATION_MISSING_NAME_RESPONSE = json.dumps(OPERATION_MISSING_NAME) +DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) +LOCKED_MODEL_2_RESPONSE = json.dumps(LOCKED_MODEL_JSON_2) +NO_MODELS_LIST_RESPONSE = json.dumps({}) +DEFAULT_LIST_RESPONSE = json.dumps({ + 'models': [MODEL_JSON_1, MODEL_JSON_2], + 'nextPageToken': NEXT_PAGE_TOKEN +}) +LAST_PAGE_LIST_RESPONSE = json.dumps({ + 'models': [MODEL_JSON_3] +}) +ONE_PAGE_LIST_RESPONSE = json.dumps({ + 'models': [MODEL_JSON_1, MODEL_JSON_2, MODEL_JSON_3], +}) + +ERROR_CODE_NOT_FOUND = 404 +ERROR_MSG_NOT_FOUND = 'The resource was not found' +ERROR_STATUS_NOT_FOUND = 'NOT_FOUND' +ERROR_JSON_NOT_FOUND = { + 'error': { + 'code': ERROR_CODE_NOT_FOUND, + 'message': 
ERROR_MSG_NOT_FOUND, + 'status': ERROR_STATUS_NOT_FOUND + } +} +ERROR_RESPONSE_NOT_FOUND = json.dumps(ERROR_JSON_NOT_FOUND) + +ERROR_CODE_BAD_REQUEST = 400 +ERROR_MSG_BAD_REQUEST = 'Invalid Argument' +ERROR_STATUS_BAD_REQUEST = 'INVALID_ARGUMENT' +ERROR_JSON_BAD_REQUEST = { + 'error': { + 'code': ERROR_CODE_BAD_REQUEST, + 'message': ERROR_MSG_BAD_REQUEST, + 'status': ERROR_STATUS_BAD_REQUEST + } +} +ERROR_RESPONSE_BAD_REQUEST = json.dumps(ERROR_JSON_BAD_REQUEST) + +INVALID_MODEL_ID_ARGS = [ + ('', ValueError), + ('&_*#@:/?', ValueError), + (None, TypeError), + (12345, TypeError), +] +INVALID_MODEL_ARGS = [ + 'abc', + 4.2, + list(), + dict(), + True, + -1, + 0, + None +] +INVALID_OP_NAME_ARGS = [ + 'abc', + '123', + 'operations/project/1234/model/abc/operation/123', + 'projects/operations/123', + 'projects/$#@/operations/123', + 'projects/1234/operations/123/extrathing', +] +PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ + '1 and {0}'.format(ml._MAX_PAGE_SIZE) +INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, list(), dict()] + + +# For validation type errors +def check_error(excinfo, err_type, msg=None): + err = excinfo.value + assert isinstance(err, err_type) + if msg: + assert str(err) == msg + + +# For errors that are returned in an operation +def check_operation_error(excinfo, code, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert err.code == code + assert str(err) == msg + + +# For rpc errors +def check_firebase_error(excinfo, code, status, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert err.code == code + assert err.http_response is not None + assert err.http_response.status_code == status + assert str(err) == msg + + +def instrument_ml_service(status=200, payload=None, operations=False, app=None): + if not app: + app = firebase_admin.get_app() + ml_service = ml._get_ml_service(app) + recorder = [] + session_url = 'https://firebaseml.googleapis.com/v1beta2/' + + if isinstance(status, list): + adapter = testutils.MockMultiRequestAdapter + else: + adapter = testutils.MockAdapter + + if operations: + ml_service._operation_client.session.mount( + session_url, adapter(payload, status, recorder)) + else: + ml_service._client.session.mount( + session_url, adapter(payload, status, recorder)) + return recorder + +class _TestStorageClient: + @staticmethod + def upload(bucket_name, model_file_name, app): + del app # unused variable + blob_name = ml._CloudStorageClient.BLOB_NAME.format(model_file_name) + return ml._CloudStorageClient.GCS_URI.format(bucket_name, blob_name) + + @staticmethod + def sign_uri(gcs_tflite_uri, app): + del app # unused variable + bucket_name, blob_name = ml._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) + return GCS_TFLITE_SIGNED_URI_PATTERN.format(bucket_name, blob_name) + +class TestModel: + """Tests ml.Model class.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + ml.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient() + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _op_url(project_id): + return BASE_URL + \ + 'projects/{0}/operations/123'.format(project_id) + + def test_model_success_err_state_lro(self): + model = ml.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) + assert model.model_id == MODEL_ID_1 + assert model.display_name 
+class _TestStorageClient:
+    @staticmethod
+    def upload(bucket_name, model_file_name, app):
+        del app # unused variable
+        blob_name = ml._CloudStorageClient.BLOB_NAME.format(model_file_name)
+        return ml._CloudStorageClient.GCS_URI.format(bucket_name, blob_name)
+
+    @staticmethod
+    def sign_uri(gcs_tflite_uri, app):
+        del app # unused variable
+        bucket_name, blob_name = ml._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri)
+        return GCS_TFLITE_SIGNED_URI_PATTERN.format(bucket_name, blob_name)
+
+
+class TestModel:
+    """Tests ml.Model class."""
+    @classmethod
+    def setup_class(cls):
+        cred = testutils.MockCredential()
+        firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
+        ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test
+        ml.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient()
+
+    @classmethod
+    def teardown_class(cls):
+        testutils.cleanup_apps()
+
+    @staticmethod
+    def _op_url(project_id):
+        return BASE_URL + \
+            'projects/{0}/operations/123'.format(project_id)
+
+    def test_model_success_err_state_lro(self):
+        model = ml.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON)
+        assert model.model_id == MODEL_ID_1
+        assert model.display_name == DISPLAY_NAME_1
+        assert model.create_time == CREATE_TIME_MILLIS
+        assert model.update_time == UPDATE_TIME_MILLIS
+        assert model.validation_error == VALIDATION_ERROR_MSG
+        assert model.published is False
+        assert model.etag == ETAG
+        assert model.model_hash == MODEL_HASH
+        assert model.tags == TAGS
+        assert model.locked is True
+        assert model.model_format is None
+        assert model.as_dict() == FULL_MODEL_ERR_STATE_LRO_JSON
+
+    def test_model_success_published(self):
+        model = ml.Model.from_dict(FULL_MODEL_PUBLISHED_JSON)
+        assert model.model_id == MODEL_ID_1
+        assert model.display_name == DISPLAY_NAME_1
+        assert model.create_time == CREATE_TIME_MILLIS
+        assert model.update_time == UPDATE_TIME_MILLIS
+        assert model.validation_error is None
+        assert model.published is True
+        assert model.etag == ETAG
+        assert model.model_hash == MODEL_HASH
+        assert model.tags == TAGS
+        assert model.locked is False
+        assert model.model_format == TFLITE_FORMAT
+        assert model.as_dict() == FULL_MODEL_PUBLISHED_JSON
+
+    def test_model_keyword_based_creation_and_setters(self):
+        model = ml.Model(display_name=DISPLAY_NAME_1, tags=TAGS, model_format=TFLITE_FORMAT)
+        assert model.display_name == DISPLAY_NAME_1
+        assert model.tags == TAGS
+        assert model.model_format == TFLITE_FORMAT
+        assert model.as_dict() == {
+            'displayName': DISPLAY_NAME_1,
+            'tags': TAGS,
+            'tfliteModel': TFLITE_FORMAT_JSON
+        }
+
+        model.display_name = DISPLAY_NAME_2
+        model.tags = TAGS_2
+        model.model_format = TFLITE_FORMAT_2
+        assert model.as_dict() == {
+            'displayName': DISPLAY_NAME_2,
+            'tags': TAGS_2,
+            'tfliteModel': TFLITE_FORMAT_JSON_2
+        }
+
+    def test_model_format_source_creation(self):
+        model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
+        model_format = ml.TFLiteFormat(model_source=model_source)
+        model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format)
+        assert model.as_dict() == {
+            'displayName': DISPLAY_NAME_1,
+            'tfliteModel': {
+                'gcsTfliteUri': GCS_TFLITE_URI
+            }
+        }
+
+    def test_source_creation_from_tflite_file(self):
+        model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(
+            "my_model.tflite", "my_bucket")
+        assert model_source.as_dict() == {
+            'gcsTfliteUri': 'gs://my_bucket/Firebase/ML/Models/my_model.tflite'
+        }
+
+    def test_model_source_setters(self):
+        model_source = ml.TFLiteGCSModelSource(GCS_TFLITE_URI)
+        model_source.gcs_tflite_uri = GCS_TFLITE_URI_2
+        assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2
+        assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2
+
+    def test_model_format_setters(self):
+        model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE)
+        model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2
+        assert model_format.model_source == GCS_TFLITE_MODEL_SOURCE_2
+        assert model_format.as_dict() == {
+            'tfliteModel': {
+                'gcsTfliteUri': GCS_TFLITE_URI_2
+            }
+        }
+
+    def test_model_as_dict_for_upload(self):
+        model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
+        model_format = ml.TFLiteFormat(model_source=model_source)
+        model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format)
+        assert model.as_dict(for_upload=True) == {
+            'displayName': DISPLAY_NAME_1,
+            'tfliteModel': {
+                'gcsTfliteUri': GCS_TFLITE_SIGNED_URI
+            }
+        }
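+
+    # For orientation: the construction pattern exercised above corresponds to roughly the
+    # following application code (illustrative names; assumes firebase_admin.initialize_app()
+    # has already been called for a project with Firebase ML enabled):
+    #
+    #     source = ml.TFLiteGCSModelSource.from_tflite_model_file('example.tflite', 'example-bucket')
+    #     model = ml.Model(display_name='example_model', tags=['demo'],
+    #                      model_format=ml.TFLiteFormat(model_source=source))
+    #     created = ml.create_model(model)      # may still be locked while validation finishes
+    #     ml.publish_model(created.model_id)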
+
+    @pytest.mark.parametrize('helper_func', [
+        ml.TFLiteGCSModelSource.from_keras_model,
+        ml.TFLiteGCSModelSource.from_saved_model
+    ])
+    def test_tf_not_enabled(self, helper_func):
+        ml._TF_ENABLED = False # for reliability
+        with pytest.raises(ImportError) as excinfo:
+            helper_func(None)
+        check_error(excinfo, ImportError)
+
+    @pytest.mark.parametrize('display_name, exc_type', [
+        ('', ValueError),
+        ('&_*#@:/?', ValueError),
+        (12345, TypeError)
+    ])
+    def test_model_display_name_validation_errors(self, display_name, exc_type):
+        with pytest.raises(exc_type) as excinfo:
+            ml.Model(display_name=display_name)
+        check_error(excinfo, exc_type)
+
+    @pytest.mark.parametrize('tags, exc_type, error_message', [
+        ('tag1', TypeError, 'Tags must be a list of strings.'),
+        (123, TypeError, 'Tags must be a list of strings.'),
+        (['tag1', 123, 'tag2'], TypeError, 'Tags must be a list of strings.'),
+        (['tag1', '@#$%^&'], ValueError, 'Tag format is invalid.'),
+        (['', 'tag2'], ValueError, 'Tag format is invalid.'),
+        (['sixty-one_characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx',
+          'tag2'], ValueError, 'Tag format is invalid.')
+    ])
+    def test_model_tags_validation_errors(self, tags, exc_type, error_message):
+        with pytest.raises(exc_type) as excinfo:
+            ml.Model(tags=tags)
+        check_error(excinfo, exc_type, error_message)
+
+    @pytest.mark.parametrize('model_format', [
+        123,
+        "abc",
+        {},
+        [],
+        True
+    ])
+    def test_model_format_validation_errors(self, model_format):
+        with pytest.raises(TypeError) as excinfo:
+            ml.Model(model_format=model_format)
+        check_error(excinfo, TypeError, 'Model format must be a ModelFormat object.')
+
+    @pytest.mark.parametrize('model_source', [
+        123,
+        "abc",
+        {},
+        [],
+        True
+    ])
+    def test_model_source_validation_errors(self, model_source):
+        with pytest.raises(TypeError) as excinfo:
+            ml.TFLiteFormat(model_source=model_source)
+        check_error(excinfo, TypeError, 'Model source must be a TFLiteModelSource object.')
+
+    @pytest.mark.parametrize('uri, exc_type', [
+        (123, TypeError),
+        ('abc', ValueError),
+        ('gs://NO_CAPITALS', ValueError),
+        ('gs://abc/', ValueError),
+        ('gs://aa/model.tflite', ValueError),
+        ('gs://@#$%/model.tflite', ValueError),
+        ('gs://invalid space/model.tflite', ValueError),
+        ('gs://sixty-four-characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx/model.tflite',
+         ValueError)
+    ])
+    def test_gcs_tflite_source_validation_errors(self, uri, exc_type):
+        with pytest.raises(exc_type) as excinfo:
+            ml.TFLiteGCSModelSource(gcs_tflite_uri=uri)
+        check_error(excinfo, exc_type)
+
+    def test_wait_for_unlocked_not_locked(self):
+        model = ml.Model(display_name="not_locked")
+        model.wait_for_unlocked()
+
+    def test_wait_for_unlocked(self):
+        recorder = instrument_ml_service(status=200,
+                                         operations=True,
+                                         payload=OPERATION_DONE_PUBLISHED_RESPONSE)
+        model = ml.Model.from_dict(LOCKED_MODEL_JSON_1)
+        model.wait_for_unlocked()
+        assert model == FULL_MODEL_PUBLISHED
+        assert len(recorder) == 1
+        assert recorder[0].method == 'GET'
+        assert recorder[0].url == TestModel._op_url(PROJECT_ID)
+        assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
+
+    def test_wait_for_unlocked_timeout(self):
+        recorder = instrument_ml_service(
+            status=200, operations=True, payload=OPERATION_NOT_DONE_RESPONSE)
+        ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 3 # longer so timeout applies immediately
+        model = ml.Model.from_dict(LOCKED_MODEL_JSON_1)
+        with pytest.raises(Exception) as excinfo:
+            model.wait_for_unlocked(max_time_seconds=0.1)
+        check_error(excinfo, exceptions.DeadlineExceededError, 'Polling max time exceeded.')
+        assert len(recorder) == 1
+
+
+class TestCreateModel:
+    """Tests ml.create_model."""
+    @classmethod
+    def setup_class(cls):
+        cred = testutils.MockCredential()
+        firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
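+        # Shorten the base wait between operation polls so the create_model tests that
+        # exercise long-running operations do not sleep through real backoff intervals.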
ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id): + return BASE_URL + 'projects/{0}/models'.format(project_id) + + @staticmethod + def _op_url(project_id): + return BASE_URL + \ + 'projects/{0}/operations/123'.format(project_id) + + @staticmethod + def _get_url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + def test_immediate_done(self): + instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = ml.create_model(MODEL_1) + assert model == CREATED_UPDATED_MODEL_1 + + def test_returns_locked(self): + recorder = instrument_ml_service( + status=[200, 200], + payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) + model = ml.create_model(MODEL_1) + + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'POST' + assert recorder[0].url == TestCreateModel._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + assert recorder[1].method == 'GET' + assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + def test_operation_error(self): + instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) + with pytest.raises(Exception) as excinfo: + ml.create_model(MODEL_1) + # The http request succeeded, the operation returned contains a create failure + check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) + + def test_malformed_operation(self): + instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + ml.create_model(MODEL_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') + + def test_rpc_error_create(self): + create_recorder = instrument_ml_service( + status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(Exception) as excinfo: + ml.create_model(MODEL_1) + check_firebase_error( + excinfo, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(create_recorder) == 1 + + @pytest.mark.parametrize('model', INVALID_MODEL_ARGS) + def test_not_model(self, model): + with pytest.raises(Exception) as excinfo: + ml.create_model(model) + check_error(excinfo, TypeError, 'Model must be an ml.Model.') + + def test_missing_display_name(self): + with pytest.raises(Exception) as excinfo: + ml.create_model(ml.Model.from_dict({})) + check_error(excinfo, ValueError, 'Model must have a display name.') + + def test_missing_op_name(self): + instrument_ml_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) + with pytest.raises(Exception) as excinfo: + ml.create_model(MODEL_1) + check_error(excinfo, TypeError) + + @pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS) + def test_invalid_op_name(self, op_name): + payload = json.dumps({'name': op_name}) + instrument_ml_service(status=200, payload=payload) + with pytest.raises(Exception) as excinfo: + ml.create_model(MODEL_1) + check_error(excinfo, ValueError, 'Operation name format is invalid.') + + +class TestUpdateModel: + """Tests ml.update_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter 
for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + @staticmethod + def _op_url(project_id): + return BASE_URL + \ + 'projects/{0}/operations/123'.format(project_id) + + def test_immediate_done(self): + instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = ml.update_model(MODEL_1) + assert model == CREATED_UPDATED_MODEL_1 + + def test_returns_locked(self): + recorder = instrument_ml_service( + status=[200, 200], + payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) + model = ml.update_model(MODEL_1) + + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + assert recorder[1].method == 'GET' + assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + def test_operation_error(self): + instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) + with pytest.raises(Exception) as excinfo: + ml.update_model(MODEL_1) + # The http request succeeded, the operation returned contains an update failure + check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) + + def test_malformed_operation(self): + instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + ml.update_model(MODEL_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') + + def test_rpc_error(self): + create_recorder = instrument_ml_service( + status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(Exception) as excinfo: + ml.update_model(MODEL_1) + check_firebase_error( + excinfo, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(create_recorder) == 1 + + @pytest.mark.parametrize('model', INVALID_MODEL_ARGS) + def test_not_model(self, model): + with pytest.raises(Exception) as excinfo: + ml.update_model(model) + check_error(excinfo, TypeError, 'Model must be an ml.Model.') + + def test_missing_display_name(self): + with pytest.raises(Exception) as excinfo: + ml.update_model(ml.Model.from_dict({})) + check_error(excinfo, ValueError, 'Model must have a display name.') + + def test_missing_op_name(self): + instrument_ml_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) + with pytest.raises(Exception) as excinfo: + ml.update_model(MODEL_1) + check_error(excinfo, TypeError) + + @pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS) + def test_invalid_op_name(self, op_name): + payload = json.dumps({'name': op_name}) + instrument_ml_service(status=200, payload=payload) + with pytest.raises(Exception) as excinfo: + ml.update_model(MODEL_1) + check_error(excinfo, ValueError, 'Operation name format is invalid.') + + +class TestPublishUnpublish: + """Tests ml.publish_model and ml.unpublish_model.""" + + PUBLISH_UNPUBLISH_WITH_ARGS = [ + (ml.publish_model, True), + (ml.unpublish_model, False) + ] + PUBLISH_UNPUBLISH_FUNCS = [item[0] for item in PUBLISH_UNPUBLISH_WITH_ARGS] + + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + 
ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _update_url(project_id, model_id): + update_url = 'projects/{0}/models/{1}?updateMask=state.published'.format( + project_id, model_id) + return BASE_URL + update_url + + @staticmethod + def _get_url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + @staticmethod + def _op_url(project_id): + return BASE_URL + \ + 'projects/{0}/operations/123'.format(project_id) + + @pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS) + def test_immediate_done(self, publish_function, published): + recorder = instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = publish_function(MODEL_ID_1) + assert model == CREATED_UPDATED_MODEL_1 + assert len(recorder) == 1 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + body = json.loads(recorder[0].body.decode()) + assert body.get('state', {}).get('published', None) is published + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_returns_locked(self, publish_function): + recorder = instrument_ml_service( + status=[200, 200], + payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) + model = publish_function(MODEL_ID_1) + + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + assert recorder[1].method == 'GET' + assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_operation_error(self, publish_function): + instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) + with pytest.raises(Exception) as excinfo: + publish_function(MODEL_ID_1) + # The http request succeeded, the operation returned contains an update failure + check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_malformed_operation(self, publish_function): + instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + publish_function(MODEL_ID_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_rpc_error(self, publish_function): + create_recorder = instrument_ml_service( + status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(Exception) as excinfo: + publish_function(MODEL_ID_1) + check_firebase_error( + excinfo, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(create_recorder) == 1 + + +class TestGetModel: + """Tests ml.get_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def 
_url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + def test_get_model(self): + recorder = instrument_ml_service(status=200, payload=DEFAULT_GET_RESPONSE) + model = ml.get_model(MODEL_ID_1) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + assert model == MODEL_1 + assert model.model_id == MODEL_ID_1 + assert model.display_name == DISPLAY_NAME_1 + + @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) + def test_get_model_validation_errors(self, model_id, exc_type): + with pytest.raises(exc_type) as excinfo: + ml.get_model(model_id) + check_error(excinfo, exc_type) + + def test_get_model_error(self): + recorder = instrument_ml_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.get_model(MODEL_ID_1) + check_firebase_error( + excinfo, + ERROR_STATUS_NOT_FOUND, + ERROR_CODE_NOT_FOUND, + ERROR_MSG_NOT_FOUND + ) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + ml.get_model(MODEL_ID_1, app) + testutils.run_without_project_id(evaluate) + + +class TestDeleteModel: + """Tests ml.delete_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + def test_delete_model(self): + recorder = instrument_ml_service(status=200, payload=EMPTY_RESPONSE) + ml.delete_model(MODEL_ID_1) # no response for delete + assert len(recorder) == 1 + assert recorder[0].method == 'DELETE' + assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) + def test_delete_model_validation_errors(self, model_id, exc_type): + with pytest.raises(exc_type) as excinfo: + ml.delete_model(model_id) + check_error(excinfo, exc_type) + + def test_delete_model_error(self): + recorder = instrument_ml_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.delete_model(MODEL_ID_1) + check_firebase_error( + excinfo, + ERROR_STATUS_NOT_FOUND, + ERROR_CODE_NOT_FOUND, + ERROR_MSG_NOT_FOUND + ) + assert len(recorder) == 1 + assert recorder[0].method == 'DELETE' + assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + ml.delete_model(MODEL_ID_1, app) + testutils.run_without_project_id(evaluate) + + +class TestListModels: + """Tests ml.list_models.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': 
PROJECT_ID}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id): + return BASE_URL + 'projects/{0}/models'.format(project_id) + + @staticmethod + def _check_page(page, model_count): + assert isinstance(page, ml.ListModelsPage) + assert len(page.models) == model_count + for model in page.models: + assert isinstance(model, ml.Model) + + def test_list_models_no_args(self): + recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) + models_page = ml.list_models() + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestListModels._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + TestListModels._check_page(models_page, 2) + assert models_page.has_next_page + assert models_page.next_page_token == NEXT_PAGE_TOKEN + assert models_page.models[0] == MODEL_1 + assert models_page.models[1] == MODEL_2 + + def test_list_models_with_all_args(self): + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + models_page = ml.list_models( + 'display_name=displayName3', + page_size=10, + page_token=PAGE_TOKEN) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == ( + TestListModels._url(PROJECT_ID) + + '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' + .format(PAGE_TOKEN)) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + assert isinstance(models_page, ml.ListModelsPage) + assert len(models_page.models) == 1 + assert models_page.models[0] == MODEL_3 + assert not models_page.has_next_page + + @pytest.mark.parametrize('list_filter', INVALID_STRING_OR_NONE_ARGS) + def test_list_models_list_filter_validation(self, list_filter): + with pytest.raises(TypeError) as excinfo: + ml.list_models(list_filter=list_filter) + check_error(excinfo, TypeError, 'List filter must be a string or None.') + + @pytest.mark.parametrize('page_size, exc_type, error_message', [ + ('abc', TypeError, 'Page size must be a number or None.'), + (4.2, TypeError, 'Page size must be a number or None.'), + (list(), TypeError, 'Page size must be a number or None.'), + (dict(), TypeError, 'Page size must be a number or None.'), + (True, TypeError, 'Page size must be a number or None.'), + (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), + (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), + (ml._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) + ]) + def test_list_models_page_size_validation(self, page_size, exc_type, error_message): + with pytest.raises(exc_type) as excinfo: + ml.list_models(page_size=page_size) + check_error(excinfo, exc_type, error_message) + + @pytest.mark.parametrize('page_token', INVALID_STRING_OR_NONE_ARGS) + def test_list_models_page_token_validation(self, page_token): + with pytest.raises(TypeError) as excinfo: + ml.list_models(page_token=page_token) + check_error(excinfo, TypeError, 'Page token must be a string or None.') + + def test_list_models_error(self): + recorder = instrument_ml_service(status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: + ml.list_models() + check_firebase_error( + excinfo, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestListModels._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE + + def test_no_project_id(self): 
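+        # An app configured without a project ID cannot build Firebase ML request URLs,
+        # so list_models is expected to raise ValueError.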
+ def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + ml.list_models(app=app) + testutils.run_without_project_id(evaluate) + + def test_list_single_page(self): + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + models_page = ml.list_models() + assert len(recorder) == 1 + assert models_page.next_page_token == '' + assert models_page.has_next_page is False + assert models_page.get_next_page() is None + models = [model for model in models_page.iterate_all()] + assert len(models) == 1 + + def test_list_multiple_pages(self): + # Page 1 + recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) + page = ml.list_models() + assert len(recorder) == 1 + assert len(page.models) == 2 + assert page.next_page_token == NEXT_PAGE_TOKEN + assert page.has_next_page is True + + # Page 2 + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + page_2 = page.get_next_page() + assert len(recorder) == 1 + assert len(page_2.models) == 1 + assert page_2.next_page_token == '' + assert page_2.has_next_page is False + assert page_2.get_next_page() is None + + def test_list_models_paged_iteration(self): + # Page 1 + recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) + page = ml.list_models() + assert page.next_page_token == NEXT_PAGE_TOKEN + assert page.has_next_page is True + iterator = page.iterate_all() + for index in range(2): + model = next(iterator) + assert model.display_name == 'displayName{0}'.format(index+1) + assert len(recorder) == 1 + + # Page 2 + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + model = next(iterator) + assert model.display_name == DISPLAY_NAME_3 + with pytest.raises(StopIteration): + next(iterator) + + def test_list_models_stop_iteration(self): + recorder = instrument_ml_service(status=200, payload=ONE_PAGE_LIST_RESPONSE) + page = ml.list_models() + assert len(recorder) == 1 + assert len(page.models) == 3 + iterator = page.iterate_all() + models = [model for model in iterator] + assert len(page.models) == 3 + with pytest.raises(StopIteration): + next(iterator) + assert len(models) == 3 + + def test_list_models_no_models(self): + recorder = instrument_ml_service(status=200, payload=NO_MODELS_LIST_RESPONSE) + page = ml.list_models() + assert len(recorder) == 1 + assert len(page.models) == 0 + models = [model for model in page.iterate_all()] + assert len(models) == 0
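+
+    # Taken together, the pagination tests above mirror the intended user-facing iteration
+    # pattern, roughly (illustrative sketch; assumes an initialized app with a project ID):
+    #
+    #     page = ml.list_models(page_size=100)
+    #     for model in page.iterate_all():
+    #         print(model.model_id, model.display_name)
+    #
+    #     # or page-by-page:
+    #     while page:
+    #         handle(page.models)                # handle() is a placeholder for user code
+    #         page = page.get_next_page() if page.has_next_page else None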