From 981285fc6f3c2067ac2255714e14811080dce386 Mon Sep 17 00:00:00 2001 From: ifielker Date: Thu, 12 Sep 2019 13:42:13 -0400 Subject: [PATCH 1/9] Firebase ML Kit TFLiteGCSModelSource.from_tflite_model implementation --- firebase_admin/mlkit.py | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 91cedbedc..41a7ceb1c 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -18,6 +18,8 @@ deleting, publishing and unpublishing Firebase ML Kit models. """ + + import datetime import numbers import re @@ -30,6 +32,12 @@ from firebase_admin import _utils from firebase_admin import exceptions +# pylint: disable=import-error,no-name-in-module +try: + from firebase_admin import storage + GCS_ENABLED = True +except ImportError: + GCS_ENABLED = False _MLKIT_ATTRIBUTE = '_mlkit' _MAX_PAGE_SIZE = 100 @@ -379,7 +387,10 @@ def as_dict(self): class TFLiteGCSModelSource(TFLiteModelSource): """TFLite model source representing a tflite model file stored in GCS.""" - def __init__(self, gcs_tflite_uri): + BLOB_NAME = 'Firebase/MLKit/Models/{0}' + + def __init__(self, gcs_tflite_uri, app=None): + self._app = app self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) def __eq__(self, other): @@ -391,6 +402,32 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + @classmethod + def from_tflite_model(cls, model_file_name, bucket_name=None, app=None): + """Uploads the model file to an existing Google Cloud Services bucket. + + Args: + model_file_name: The name of the model file. + bucket_name: The name of an existing bucket. None to use the default bucket configured + in the app. + app: A Firebase app instance (or None to use the default app). + + Returns: + TFLiteGCSModelSource: The source created from the model_file + + Raises: + ImportError: If the Cloud Storage Library has not been installed. + """ + if not GCS_ENABLED: + raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' + 'to install the "google-cloud-storage" module.') + bucket = storage.bucket(bucket_name, app=app) + blob_name = BLOB_NAME.format(model_file_name) + blob = bucket.blob(blob_name) + blob.upload_from_filename(model_file_name) + return TFLiteGCSModelSource(gcs_tflite_uri='gs://{0}/{1}'.format(bucket.name, blob_name), + app=app) + @property def gcs_tflite_uri(self): return self._gcs_tflite_uri @@ -402,6 +439,7 @@ def gcs_tflite_uri(self, gcs_tflite_uri): def as_dict(self): return {"gcsTfliteUri": self._gcs_tflite_uri} + #TODO(ifielker): implement from_saved_model etc. From bdabdc46c40d8b5d84a2b86e18afc6cee94851ea Mon Sep 17 00:00:00 2001 From: ifielker Date: Fri, 13 Sep 2019 15:15:10 -0400 Subject: [PATCH 2/9] fixed --- firebase_admin/mlkit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 41a7ceb1c..e8325e990 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -422,7 +422,7 @@ def from_tflite_model(cls, model_file_name, bucket_name=None, app=None): raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' 'to install the "google-cloud-storage" module.') bucket = storage.bucket(bucket_name, app=app) - blob_name = BLOB_NAME.format(model_file_name) + blob_name = TFLiteGCSModelSource.BLOB_NAME.format(model_file_name) blob = bucket.blob(blob_name) blob.upload_from_filename(model_file_name) return TFLiteGCSModelSource(gcs_tflite_uri='gs://{0}/{1}'.format(bucket.name, blob_name), From 921dd59344725fc1de016acd617607370c17933a Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 16 Sep 2019 17:03:39 -0400 Subject: [PATCH 3/9] fixed tests --- firebase_admin/mlkit.py | 85 +++++++++++++++++++++++++++++++---------- tests/test_mlkit.py | 26 +++++++++++++ 2 files changed, 90 insertions(+), 21 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index e8325e990..974c38151 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -44,7 +44,8 @@ _MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') _DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') _TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') -_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+') +_GCS_TFLITE_URI_PATTERN = re.compile( + r'^gs://(?P[a-z0-9_.-]{3,63})/(?P.+)$') _RESOURCE_NAME_PATTERN = re.compile( r'^projects/(?P[^/]+)/models/(?P[A-Za-z0-9_-]{1,60})$') _OPERATION_NAME_PATTERN = re.compile( @@ -309,16 +310,16 @@ def model_format(self, model_format): self._model_format = model_format #Can be None return self - def as_dict(self): + def as_dict(self, for_upload=False): copy = dict(self._data) if self._model_format: - copy.update(self._model_format.as_dict()) + copy.update(self._model_format.as_dict(for_upload=for_upload)) return copy class ModelFormat(object): """Abstract base class representing a Model Format such as TFLite.""" - def as_dict(self): + def as_dict(self, for_upload=False): raise NotImplementedError @@ -372,22 +373,55 @@ def model_source(self, model_source): def size_bytes(self): return self._data.get('sizeBytes') - def as_dict(self): + def as_dict(self, for_upload=False): copy = dict(self._data) if self._model_source: - copy.update(self._model_source.as_dict()) + copy.update(self._model_source.as_dict(for_upload=for_upload)) return {'tfliteModel': copy} class TFLiteModelSource(object): """Abstract base class representing a model source for TFLite format models.""" - def as_dict(self): + def as_dict(self, for_upload=False): raise NotImplementedError +class _CloudStorageClient(object): + """Cloud Storage helper class""" + + GCS_URI = 'gs://{0}/{1}' + BLOB_NAME = 'Firebase/MLKit/Models/{0}' + + def __init__(self): + if not GCS_ENABLED: + raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' + 'to install the "google-cloud-storage" module.') + + @staticmethod + def upload(bucket_name, model_file_name, app): + bucket = storage.bucket(bucket_name, app=app) + blob_name = _CloudStorageClient.BLOB_NAME.format(model_file_name) + blob = bucket.blob(blob_name) + blob.upload_from_filename(model_file_name) + return GCS_URI.format(bucket.name, blob_name) + + @staticmethod + def sign_uri(gcs_tflite_uri, app): + """Makes the gcs_tflite_uri readable for GET for 10 minutes.""" + bucket_name, blob_name = _parse_gcs_tflite_uri(gcs_tflite_uri) + bucket = storage.bucket(bucket_name, app=app) + blob = bucket.blob(blob_name) + return blob.generate_signed_url( + version='v4', + expiration=datetime.timedelta(minutes=10), + method='GET' + ) + + class TFLiteGCSModelSource(TFLiteModelSource): """TFLite model source representing a tflite model file stored in GCS.""" - BLOB_NAME = 'Firebase/MLKit/Models/{0}' + + _STORAGE_CLIENT = _CloudStorageClient() def __init__(self, gcs_tflite_uri, app=None): self._app = app @@ -403,7 +437,7 @@ def __ne__(self, other): return not self.__eq__(other) @classmethod - def from_tflite_model(cls, model_file_name, bucket_name=None, app=None): + def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): """Uploads the model file to an existing Google Cloud Services bucket. Args: @@ -418,15 +452,8 @@ def from_tflite_model(cls, model_file_name, bucket_name=None, app=None): Raises: ImportError: If the Cloud Storage Library has not been installed. """ - if not GCS_ENABLED: - raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' - 'to install the "google-cloud-storage" module.') - bucket = storage.bucket(bucket_name, app=app) - blob_name = TFLiteGCSModelSource.BLOB_NAME.format(model_file_name) - blob = bucket.blob(blob_name) - blob.upload_from_filename(model_file_name) - return TFLiteGCSModelSource(gcs_tflite_uri='gs://{0}/{1}'.format(bucket.name, blob_name), - app=app) + gcs_uri = TFLiteGCSModelSource._STORAGE_CLIENT.upload(bucket_name, model_file_name, app) + return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app) @property def gcs_tflite_uri(self): @@ -436,7 +463,14 @@ def gcs_tflite_uri(self): def gcs_tflite_uri(self, gcs_tflite_uri): self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) - def as_dict(self): + def _get_signed_gcs_tflite_uri(self): + """Signs the GCS uri, so the model file can be uploaded to Firebase ML Kit and verified.""" + return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app) + + def as_dict(self, for_upload=False): + if for_upload: + return {"gcsTfliteUri": self._get_signed_gcs_tflite_uri()} + return {"gcsTfliteUri": self._gcs_tflite_uri} @@ -591,6 +625,15 @@ def _validate_gcs_tflite_uri(uri): return uri +def _parse_gcs_tflite_uri(uri): + # GCS Bucket naming rules are complex. The regex is not comprehensive. + # See https://cloud.google.com/storage/docs/naming for full details. + matcher = _GCS_TFLITE_URI_PATTERN.match(uri) + if not matcher: + raise ValueError('GCS TFLite URI format is invalid.') + return matcher.group('bucket_name'), matcher.group('blob_name') + + def _validate_model_format(model_format): if not isinstance(model_format, ModelFormat): raise TypeError('Model format must be a ModelFormat object.') @@ -709,13 +752,13 @@ def create_model(self, model): _validate_model(model) try: return self.handle_operation( - self._client.body('post', url='models', json=model.as_dict())) + self._client.body('post', url='models', json=model.as_dict(for_upload=True))) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) def update_model(self, model, update_mask=None): _validate_model(model, update_mask) - data = {'model': model.as_dict()} + data = {'model': model.as_dict(for_upload=True)} if update_mask is not None: data['updateMask'] = update_mask try: diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 50fed4e1b..3d330c48b 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -21,6 +21,7 @@ import firebase_admin from firebase_admin import exceptions from firebase_admin import mlkit +from firebase_admin import storage from tests import testutils BASE_URL = 'https://mlkit.googleapis.com/v1beta1/' @@ -112,6 +113,10 @@ } TFLITE_FORMAT = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON) +GCS_TFLITE_SIGNED_URI = 'gs://test_bucket/test_blob?signing_information' +GCS_TFLITE_SIGNED_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI} +GCS_TFLITE_SIGNED_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_SIGNED_URI) + GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite' GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2} GCS_TFLITE_MODEL_SOURCE_2 = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI_2) @@ -325,6 +330,14 @@ def instrument_mlkit_service(status=200, payload=None, operations=False, app=Non session_url, adapter(payload, status, recorder)) return recorder +class _TestStorageClient(object): + @staticmethod + def upload(bucket_name, model_file_name, app): + pass + + @staticmethod + def sign_uri(gcs_tflite_uri, app): + return GCS_TFLITE_SIGNED_URI class TestModel(object): """Tests mlkit.Model class.""" @@ -333,6 +346,7 @@ def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + mlkit.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient() @classmethod def teardown_class(cls): @@ -420,6 +434,17 @@ def test_model_format_setters(self): } } + def test_model_as_dict_for_upload(self): + model_source = mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) + model_format = mlkit.TFLiteFormat(model_source=model_source) + model = mlkit.Model(display_name=DISPLAY_NAME_1, model_format=model_format) + assert model.as_dict(for_upload=True) == { + 'displayName': DISPLAY_NAME_1, + 'tfliteModel': { + 'gcsTfliteUri': GCS_TFLITE_SIGNED_URI + } + } + @pytest.mark.parametrize('display_name, exc_type', [ ('', ValueError), ('&_*#@:/?', ValueError), @@ -803,6 +828,7 @@ def test_rpc_error(self, publish_function): ) assert len(create_recorder) == 1 + class TestGetModel(object): """Tests mlkit.get_model.""" @classmethod From 3da024e963a52e06bad70a153c5ea3c113acff23 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 17 Sep 2019 13:14:03 -0400 Subject: [PATCH 4/9] support for tensorflow lite conversion helpers --- firebase_admin/mlkit.py | 68 ++++++++++++++++++++++++++++++++++++++--- tests/test_mlkit.py | 10 +++++- 2 files changed, 73 insertions(+), 5 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 974c38151..c2141e10a 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -39,6 +39,13 @@ except ImportError: GCS_ENABLED = False +# pylint: disable=import-error,no-name-in-module +try: + import tensorflow as tf + TF_ENABLED = True +except ImportError: + TF_ENABLED = False + _MLKIT_ATTRIBUTE = '_mlkit' _MAX_PAGE_SIZE = 100 _MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') @@ -392,13 +399,11 @@ class _CloudStorageClient(object): GCS_URI = 'gs://{0}/{1}' BLOB_NAME = 'Firebase/MLKit/Models/{0}' - def __init__(self): + @staticmethod + def upload(bucket_name, model_file_name, app): if not GCS_ENABLED: raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' 'to install the "google-cloud-storage" module.') - - @staticmethod - def upload(bucket_name, model_file_name, app): bucket = storage.bucket(bucket_name, app=app) blob_name = _CloudStorageClient.BLOB_NAME.format(model_file_name) blob = bucket.blob(blob_name) @@ -408,6 +413,9 @@ def upload(bucket_name, model_file_name, app): @staticmethod def sign_uri(gcs_tflite_uri, app): """Makes the gcs_tflite_uri readable for GET for 10 minutes.""" + if not GCS_ENABLED: + raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' + 'to install the "google-cloud-storage" module.') bucket_name, blob_name = _parse_gcs_tflite_uri(gcs_tflite_uri) bucket = storage.bucket(bucket_name, app=app) blob = bucket.blob(blob_name) @@ -455,6 +463,58 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): gcs_uri = TFLiteGCSModelSource._STORAGE_CLIENT.upload(bucket_name, model_file_name, app) return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app) + @classmethod + def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): + """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS. + + Args: + saved_model_dir: The saved model directory. + bucket_name: Optional. The name of the bucket to store the uploaded tflite file. + (or None to use the default bucket) + app: Optional. A Firebase app instance (or None to use the default app) + + Returns: + TFLiteGCSModelSource: The source created from the saved_model_dir + + Raises: + ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. + """ + if not TF_ENABLED: + raise ImportError('Failed to import the tensorflow library for Python. Make sure ' + 'to install the tensorflow module.') + #TODO(ifielker): Do we need to worry about tf version? + converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) + tflite_model = converter.convert() + open("firebase_mlkit_model.tflite", "wb").write(tflite_model) + return from_tflite_model_file("firebase_mlkit_model.tflite", bucket_name, app) + + @classmethod + def from_keras_model(cls, keras_model, bucket_name=None, app=None): + """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS. + + Args: + keras_model: A tf.keras model. + bucket_name: Optional. The name of the bucket to store the uploaded tflite file. + (or None to use the default bucket) + app: Optional. A Firebase app instance (or None to use the default app) + + Returns: + TFLiteGCSModelSource: The source created from the keras_model + + Raises: + ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. + """ + if not TF_ENABLED: + raise ImportError('Failed to import the tensorflow library for Python. Make sure ' + 'to install the tensorflow module.') + #TODO(ifielker): Do we need to worry about tf version? + keras_file = "keras_model.h5" + tf.keras.models.save_model(keras_model, keras_file) + converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + open("firebase_mlkit_model.tflite", "wb").write(tflite_model) + return from_tflite_model_file("firebase_mlkit_model.tflite", bucket_name, app) + @property def gcs_tflite_uri(self): return self._gcs_tflite_uri diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 3d330c48b..f2e946631 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -333,7 +333,8 @@ def instrument_mlkit_service(status=200, payload=None, operations=False, app=Non class _TestStorageClient(object): @staticmethod def upload(bucket_name, model_file_name, app): - pass + blob_name = mlkit._CloudStorageClient.BLOB_NAME.format(model_file_name) + return mlkit._CloudStorageClient.GCS_URI.format(bucket_name, blob_name) @staticmethod def sign_uri(gcs_tflite_uri, app): @@ -418,6 +419,13 @@ def test_model_format_source_creation(self): } } + def test_source_creation_from_tflite_file(self): + model_source = mlkit.TFLiteGCSModelSource.from_tflite_model_file( + "my_model.tflite", "my_bucket") + assert model_source.as_dict() == { + 'gcsTfliteUri': 'gs://my_bucket/Firebase/MLKit/Models/my_model.tflite' + } + def test_model_source_setters(self): model_source = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) model_source.gcs_tflite_uri = GCS_TFLITE_URI_2 From dab51c9dc8aa91dc2c441e4c57111b58d1d395b9 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 17 Sep 2019 13:21:28 -0400 Subject: [PATCH 5/9] fixed --- firebase_admin/mlkit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index c2141e10a..ef2089f4b 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -408,7 +408,7 @@ def upload(bucket_name, model_file_name, app): blob_name = _CloudStorageClient.BLOB_NAME.format(model_file_name) blob = bucket.blob(blob_name) blob.upload_from_filename(model_file_name) - return GCS_URI.format(bucket.name, blob_name) + return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name) @staticmethod def sign_uri(gcs_tflite_uri, app): @@ -486,7 +486,7 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) tflite_model = converter.convert() open("firebase_mlkit_model.tflite", "wb").write(tflite_model) - return from_tflite_model_file("firebase_mlkit_model.tflite", bucket_name, app) + return TFLiteGCSModelSource.from_tflite_model_file("firebase_mlkit_model.tflite", bucket_name, app) @classmethod def from_keras_model(cls, keras_model, bucket_name=None, app=None): @@ -513,7 +513,7 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) tflite_model = converter.convert() open("firebase_mlkit_model.tflite", "wb").write(tflite_model) - return from_tflite_model_file("firebase_mlkit_model.tflite", bucket_name, app) + return TFLiteGCSModelSource.from_tflite_model_file("firebase_mlkit_model.tflite", bucket_name, app) @property def gcs_tflite_uri(self): From a9450e51f232146516317aefd57183e437f4f6fd Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 17 Sep 2019 13:27:24 -0400 Subject: [PATCH 6/9] fixed --- firebase_admin/mlkit.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index ef2089f4b..21b5b5700 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -486,7 +486,8 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) tflite_model = converter.convert() open("firebase_mlkit_model.tflite", "wb").write(tflite_model) - return TFLiteGCSModelSource.from_tflite_model_file("firebase_mlkit_model.tflite", bucket_name, app) + return TFLiteGCSModelSource.from_tflite_model_file( + "firebase_mlkit_model.tflite", bucket_name, app) @classmethod def from_keras_model(cls, keras_model, bucket_name=None, app=None): @@ -513,7 +514,8 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) tflite_model = converter.convert() open("firebase_mlkit_model.tflite", "wb").write(tflite_model) - return TFLiteGCSModelSource.from_tflite_model_file("firebase_mlkit_model.tflite", bucket_name, app) + return TFLiteGCSModelSource.from_tflite_model_file( + "firebase_mlkit_model.tflite", bucket_name, app) @property def gcs_tflite_uri(self): From 8f1214e0c5e472529d0978646d73d9027669218c Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 17 Sep 2019 13:35:44 -0400 Subject: [PATCH 7/9] fixed --- tests/test_mlkit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index f2e946631..5ac68bc39 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -21,7 +21,6 @@ import firebase_admin from firebase_admin import exceptions from firebase_admin import mlkit -from firebase_admin import storage from tests import testutils BASE_URL = 'https://mlkit.googleapis.com/v1beta1/' @@ -333,11 +332,13 @@ def instrument_mlkit_service(status=200, payload=None, operations=False, app=Non class _TestStorageClient(object): @staticmethod def upload(bucket_name, model_file_name, app): + del app # unused variable blob_name = mlkit._CloudStorageClient.BLOB_NAME.format(model_file_name) return mlkit._CloudStorageClient.GCS_URI.format(bucket_name, blob_name) @staticmethod def sign_uri(gcs_tflite_uri, app): + del gcs_tflite_uri, app # unused variables return GCS_TFLITE_SIGNED_URI class TestModel(object): From 529304c3790b8890fe36697d0aae6e844481c373 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 17 Sep 2019 16:36:57 -0400 Subject: [PATCH 8/9] review comments --- firebase_admin/mlkit.py | 60 ++++++++++++++++++++++------------------- tests/test_mlkit.py | 24 ++++++++++++----- 2 files changed, 50 insertions(+), 34 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 21b5b5700..88821fb4a 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -400,10 +400,23 @@ class _CloudStorageClient(object): BLOB_NAME = 'Firebase/MLKit/Models/{0}' @staticmethod - def upload(bucket_name, model_file_name, app): + def _assert_gcs_enabled(): if not GCS_ENABLED: raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' 'to install the "google-cloud-storage" module.') + + @staticmethod + def _parse_gcs_tflite_uri(uri): + # GCS Bucket naming rules are complex. The regex is not comprehensive. + # See https://cloud.google.com/storage/docs/naming for full details. + matcher = _GCS_TFLITE_URI_PATTERN.match(uri) + if not matcher: + raise ValueError('GCS TFLite URI format is invalid.') + return matcher.group('bucket_name'), matcher.group('blob_name') + + @staticmethod + def upload(bucket_name, model_file_name, app): + _CloudStorageClient._assert_gcs_enabled() bucket = storage.bucket(bucket_name, app=app) blob_name = _CloudStorageClient.BLOB_NAME.format(model_file_name) blob = bucket.blob(blob_name) @@ -412,11 +425,9 @@ def upload(bucket_name, model_file_name, app): @staticmethod def sign_uri(gcs_tflite_uri, app): - """Makes the gcs_tflite_uri readable for GET for 10 minutes.""" - if not GCS_ENABLED: - raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' - 'to install the "google-cloud-storage" module.') - bucket_name, blob_name = _parse_gcs_tflite_uri(gcs_tflite_uri) + """Makes the gcs_tflite_uri readable for GET for 10 minutes via signed_uri.""" + _CloudStorageClient._assert_gcs_enabled() + bucket_name, blob_name = _CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) bucket = storage.bucket(bucket_name, app=app) blob = bucket.blob(blob_name) return blob.generate_signed_url( @@ -446,7 +457,7 @@ def __ne__(self, other): @classmethod def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): - """Uploads the model file to an existing Google Cloud Services bucket. + """Uploads the model file to an existing Google Cloud Storage bucket. Args: model_file_name: The name of the model file. @@ -463,14 +474,22 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): gcs_uri = TFLiteGCSModelSource._STORAGE_CLIENT.upload(bucket_name, model_file_name, app) return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app) + @staticmethod + def _assert_tf_version_1_enabled(): + if not TF_ENABLED: + raise ImportError('Failed to import the tensorflow library for Python. Make sure ' + 'to install the tensorflow module.') + if not tf.VERSION.startswith('1.'): + raise ImportError('Expected tensorflow version 1') + @classmethod def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS. Args: saved_model_dir: The saved model directory. - bucket_name: Optional. The name of the bucket to store the uploaded tflite file. - (or None to use the default bucket) + bucket_name: The name of an existing bucket. None to use the default bucket configured + in the app. app: Optional. A Firebase app instance (or None to use the default app) Returns: @@ -479,10 +498,7 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): Raises: ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. """ - if not TF_ENABLED: - raise ImportError('Failed to import the tensorflow library for Python. Make sure ' - 'to install the tensorflow module.') - #TODO(ifielker): Do we need to worry about tf version? + TFLiteGCSModelSource._assert_tf_version_1_enabled() converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) tflite_model = converter.convert() open("firebase_mlkit_model.tflite", "wb").write(tflite_model) @@ -495,8 +511,8 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): Args: keras_model: A tf.keras model. - bucket_name: Optional. The name of the bucket to store the uploaded tflite file. - (or None to use the default bucket) + bucket_name: The name of an existing bucket. None to use the default bucket configured + in the app. app: Optional. A Firebase app instance (or None to use the default app) Returns: @@ -505,10 +521,7 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): Raises: ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. """ - if not TF_ENABLED: - raise ImportError('Failed to import the tensorflow library for Python. Make sure ' - 'to install the tensorflow module.') - #TODO(ifielker): Do we need to worry about tf version? + TFLiteGCSModelSource._assert_tf_version_1_enabled() keras_file = "keras_model.h5" tf.keras.models.save_model(keras_model, keras_file) converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) @@ -687,15 +700,6 @@ def _validate_gcs_tflite_uri(uri): return uri -def _parse_gcs_tflite_uri(uri): - # GCS Bucket naming rules are complex. The regex is not comprehensive. - # See https://cloud.google.com/storage/docs/naming for full details. - matcher = _GCS_TFLITE_URI_PATTERN.match(uri) - if not matcher: - raise ValueError('GCS TFLite URI format is invalid.') - return matcher.group('bucket_name'), matcher.group('blob_name') - - def _validate_model_format(model_format): if not isinstance(model_format, ModelFormat): raise TypeError('Model format must be a ModelFormat object.') diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 5ac68bc39..bd359e015 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -103,7 +103,9 @@ } } -GCS_TFLITE_URI = 'gs://my_bucket/mymodel.tflite' +GCS_BUCKET_NAME = 'my_bucket' +GCS_BLOB_NAME = 'mymodel.tflite' +GCS_TFLITE_URI = 'gs://{0}/{1}'.format(GCS_BUCKET_NAME, GCS_BLOB_NAME) GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI} GCS_TFLITE_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) TFLITE_FORMAT_JSON = { @@ -112,9 +114,9 @@ } TFLITE_FORMAT = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON) -GCS_TFLITE_SIGNED_URI = 'gs://test_bucket/test_blob?signing_information' -GCS_TFLITE_SIGNED_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI} -GCS_TFLITE_SIGNED_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_SIGNED_URI) +GCS_TFLITE_SIGNED_URI_PATTERN = ( + 'https://storage.googleapis.com/{0}/{1}?X-Goog-Algorithm=GOOG4-RSA-SHA256&foo') +GCS_TFLITE_SIGNED_URI = GCS_TFLITE_SIGNED_URI_PATTERN.format(GCS_BUCKET_NAME, GCS_BLOB_NAME) GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite' GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2} @@ -338,8 +340,9 @@ def upload(bucket_name, model_file_name, app): @staticmethod def sign_uri(gcs_tflite_uri, app): - del gcs_tflite_uri, app # unused variables - return GCS_TFLITE_SIGNED_URI + del app # unused variable + bucket_name, blob_name = mlkit._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) + return GCS_TFLITE_SIGNED_URI_PATTERN.format(bucket_name, blob_name) class TestModel(object): """Tests mlkit.Model class.""" @@ -454,6 +457,15 @@ def test_model_as_dict_for_upload(self): } } + @pytest.mark.parametrize('helper_func', [ + mlkit.TFLiteGCSModelSource.from_keras_model, + mlkit.TFLiteGCSModelSource.from_saved_model + ]) + def test_tf_not_enabled(self, helper_func): + with pytest.raises(ImportError) as excinfo: + helper_func(None) + check_error(excinfo, ImportError) + @pytest.mark.parametrize('display_name, exc_type', [ ('', ValueError), ('&_*#@:/?', ValueError), From 21f29f868d4308342dd132dff4c367c1ffaa4b30 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 17 Sep 2019 17:26:56 -0400 Subject: [PATCH 9/9] review comments --- firebase_admin/mlkit.py | 32 ++++++++++++++------------------ tests/test_mlkit.py | 1 + 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 88821fb4a..8e78a26ce 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -19,7 +19,6 @@ """ - import datetime import numbers import re @@ -35,16 +34,16 @@ # pylint: disable=import-error,no-name-in-module try: from firebase_admin import storage - GCS_ENABLED = True + _GCS_ENABLED = True except ImportError: - GCS_ENABLED = False + _GCS_ENABLED = False # pylint: disable=import-error,no-name-in-module try: import tensorflow as tf - TF_ENABLED = True + _TF_ENABLED = True except ImportError: - TF_ENABLED = False + _TF_ENABLED = False _MLKIT_ATTRIBUTE = '_mlkit' _MAX_PAGE_SIZE = 100 @@ -401,7 +400,7 @@ class _CloudStorageClient(object): @staticmethod def _assert_gcs_enabled(): - if not GCS_ENABLED: + if not _GCS_ENABLED: raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' 'to install the "google-cloud-storage" module.') @@ -476,11 +475,11 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): @staticmethod def _assert_tf_version_1_enabled(): - if not TF_ENABLED: + if not _TF_ENABLED: raise ImportError('Failed to import the tensorflow library for Python. Make sure ' 'to install the tensorflow module.') if not tf.VERSION.startswith('1.'): - raise ImportError('Expected tensorflow version 1') + raise ImportError('Expected tensorflow version 1.x, but found {0}'.format(tf.VERSION)) @classmethod def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): @@ -501,9 +500,9 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): TFLiteGCSModelSource._assert_tf_version_1_enabled() converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) tflite_model = converter.convert() - open("firebase_mlkit_model.tflite", "wb").write(tflite_model) + open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file( - "firebase_mlkit_model.tflite", bucket_name, app) + 'firebase_mlkit_model.tflite', bucket_name, app) @classmethod def from_keras_model(cls, keras_model, bucket_name=None, app=None): @@ -522,13 +521,13 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. """ TFLiteGCSModelSource._assert_tf_version_1_enabled() - keras_file = "keras_model.h5" + keras_file = 'keras_model.h5' tf.keras.models.save_model(keras_model, keras_file) converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) tflite_model = converter.convert() - open("firebase_mlkit_model.tflite", "wb").write(tflite_model) + open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file( - "firebase_mlkit_model.tflite", bucket_name, app) + 'firebase_mlkit_model.tflite', bucket_name, app) @property def gcs_tflite_uri(self): @@ -544,12 +543,9 @@ def _get_signed_gcs_tflite_uri(self): def as_dict(self, for_upload=False): if for_upload: - return {"gcsTfliteUri": self._get_signed_gcs_tflite_uri()} - - return {"gcsTfliteUri": self._gcs_tflite_uri} - + return {'gcsTfliteUri': self._get_signed_gcs_tflite_uri()} - #TODO(ifielker): implement from_saved_model etc. + return {'gcsTfliteUri': self._gcs_tflite_uri} class ListModelsPage(object): diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index bd359e015..26afdfa99 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -462,6 +462,7 @@ def test_model_as_dict_for_upload(self): mlkit.TFLiteGCSModelSource.from_saved_model ]) def test_tf_not_enabled(self, helper_func): + mlkit._TF_ENABLED = False # for reliability with pytest.raises(ImportError) as excinfo: helper_func(None) check_error(excinfo, ImportError)