Skip to content

Commit cd5e82a

Browse files
authored
Firebase ML Kit TFLiteGCSModelSource.from_tflite_model implementation and conversion helpers (#346)
* Firebase ML Kit TFLiteGCSModelSource.from_tflite_model implementation * support for tensorflow lite conversion helpers (version 1.x)
1 parent 0344172 commit cd5e82a

File tree

2 files changed

+205
-14
lines changed

2 files changed

+205
-14
lines changed

firebase_admin/mlkit.py

Lines changed: 156 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
deleting, publishing and unpublishing Firebase ML Kit models.
1919
"""
2020

21+
2122
import datetime
2223
import numbers
2324
import re
@@ -30,13 +31,27 @@
3031
from firebase_admin import _utils
3132
from firebase_admin import exceptions
3233

34+
# pylint: disable=import-error,no-name-in-module
35+
try:
36+
from firebase_admin import storage
37+
_GCS_ENABLED = True
38+
except ImportError:
39+
_GCS_ENABLED = False
40+
41+
# pylint: disable=import-error,no-name-in-module
42+
try:
43+
import tensorflow as tf
44+
_TF_ENABLED = True
45+
except ImportError:
46+
_TF_ENABLED = False
3347

3448
_MLKIT_ATTRIBUTE = '_mlkit'
3549
_MAX_PAGE_SIZE = 100
3650
_MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
3751
_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
3852
_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
39-
_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+')
53+
_GCS_TFLITE_URI_PATTERN = re.compile(
54+
r'^gs://(?P<bucket_name>[a-z0-9_.-]{3,63})/(?P<blob_name>.+)$')
4055
_RESOURCE_NAME_PATTERN = re.compile(
4156
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
4257
_OPERATION_NAME_PATTERN = re.compile(
@@ -301,16 +316,16 @@ def model_format(self, model_format):
301316
self._model_format = model_format #Can be None
302317
return self
303318

304-
def as_dict(self):
319+
def as_dict(self, for_upload=False):
305320
copy = dict(self._data)
306321
if self._model_format:
307-
copy.update(self._model_format.as_dict())
322+
copy.update(self._model_format.as_dict(for_upload=for_upload))
308323
return copy
309324

310325

311326
class ModelFormat(object):
312327
"""Abstract base class representing a Model Format such as TFLite."""
313-
def as_dict(self):
328+
def as_dict(self, for_upload=False):
314329
raise NotImplementedError
315330

316331

@@ -364,22 +379,70 @@ def model_source(self, model_source):
364379
def size_bytes(self):
365380
return self._data.get('sizeBytes')
366381

367-
def as_dict(self):
382+
def as_dict(self, for_upload=False):
368383
copy = dict(self._data)
369384
if self._model_source:
370-
copy.update(self._model_source.as_dict())
385+
copy.update(self._model_source.as_dict(for_upload=for_upload))
371386
return {'tfliteModel': copy}
372387

373388

374389
class TFLiteModelSource(object):
375390
"""Abstract base class representing a model source for TFLite format models."""
376-
def as_dict(self):
391+
def as_dict(self, for_upload=False):
377392
raise NotImplementedError
378393

379394

395+
class _CloudStorageClient(object):
396+
"""Cloud Storage helper class"""
397+
398+
GCS_URI = 'gs://{0}/{1}'
399+
BLOB_NAME = 'Firebase/MLKit/Models/{0}'
400+
401+
@staticmethod
402+
def _assert_gcs_enabled():
403+
if not _GCS_ENABLED:
404+
raise ImportError('Failed to import the Cloud Storage library for Python. Make sure '
405+
'to install the "google-cloud-storage" module.')
406+
407+
@staticmethod
408+
def _parse_gcs_tflite_uri(uri):
409+
# GCS Bucket naming rules are complex. The regex is not comprehensive.
410+
# See https://cloud.google.com/storage/docs/naming for full details.
411+
matcher = _GCS_TFLITE_URI_PATTERN.match(uri)
412+
if not matcher:
413+
raise ValueError('GCS TFLite URI format is invalid.')
414+
return matcher.group('bucket_name'), matcher.group('blob_name')
415+
416+
@staticmethod
417+
def upload(bucket_name, model_file_name, app):
418+
_CloudStorageClient._assert_gcs_enabled()
419+
bucket = storage.bucket(bucket_name, app=app)
420+
blob_name = _CloudStorageClient.BLOB_NAME.format(model_file_name)
421+
blob = bucket.blob(blob_name)
422+
blob.upload_from_filename(model_file_name)
423+
return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name)
424+
425+
@staticmethod
426+
def sign_uri(gcs_tflite_uri, app):
427+
"""Makes the gcs_tflite_uri readable for GET for 10 minutes via signed_uri."""
428+
_CloudStorageClient._assert_gcs_enabled()
429+
bucket_name, blob_name = _CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri)
430+
bucket = storage.bucket(bucket_name, app=app)
431+
blob = bucket.blob(blob_name)
432+
return blob.generate_signed_url(
433+
version='v4',
434+
expiration=datetime.timedelta(minutes=10),
435+
method='GET'
436+
)
437+
438+
380439
class TFLiteGCSModelSource(TFLiteModelSource):
381440
"""TFLite model source representing a tflite model file stored in GCS."""
382-
def __init__(self, gcs_tflite_uri):
441+
442+
_STORAGE_CLIENT = _CloudStorageClient()
443+
444+
def __init__(self, gcs_tflite_uri, app=None):
445+
self._app = app
383446
self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)
384447

385448
def __eq__(self, other):
@@ -391,6 +454,81 @@ def __eq__(self, other):
391454
def __ne__(self, other):
392455
return not self.__eq__(other)
393456

457+
@classmethod
458+
def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None):
459+
"""Uploads the model file to an existing Google Cloud Storage bucket.
460+
461+
Args:
462+
model_file_name: The name of the model file.
463+
bucket_name: The name of an existing bucket. None to use the default bucket configured
464+
in the app.
465+
app: A Firebase app instance (or None to use the default app).
466+
467+
Returns:
468+
TFLiteGCSModelSource: The source created from the model_file
469+
470+
Raises:
471+
ImportError: If the Cloud Storage Library has not been installed.
472+
"""
473+
gcs_uri = TFLiteGCSModelSource._STORAGE_CLIENT.upload(bucket_name, model_file_name, app)
474+
return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app)
475+
476+
@staticmethod
477+
def _assert_tf_version_1_enabled():
478+
if not _TF_ENABLED:
479+
raise ImportError('Failed to import the tensorflow library for Python. Make sure '
480+
'to install the tensorflow module.')
481+
if not tf.VERSION.startswith('1.'):
482+
raise ImportError('Expected tensorflow version 1.x, but found {0}'.format(tf.VERSION))
483+
484+
@classmethod
485+
def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
486+
"""Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS.
487+
488+
Args:
489+
saved_model_dir: The saved model directory.
490+
bucket_name: The name of an existing bucket. None to use the default bucket configured
491+
in the app.
492+
app: Optional. A Firebase app instance (or None to use the default app)
493+
494+
Returns:
495+
TFLiteGCSModelSource: The source created from the saved_model_dir
496+
497+
Raises:
498+
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
499+
"""
500+
TFLiteGCSModelSource._assert_tf_version_1_enabled()
501+
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
502+
tflite_model = converter.convert()
503+
open('firebase_mlkit_model.tflite', 'wb').write(tflite_model)
504+
return TFLiteGCSModelSource.from_tflite_model_file(
505+
'firebase_mlkit_model.tflite', bucket_name, app)
506+
507+
@classmethod
508+
def from_keras_model(cls, keras_model, bucket_name=None, app=None):
509+
"""Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS.
510+
511+
Args:
512+
keras_model: A tf.keras model.
513+
bucket_name: The name of an existing bucket. None to use the default bucket configured
514+
in the app.
515+
app: Optional. A Firebase app instance (or None to use the default app)
516+
517+
Returns:
518+
TFLiteGCSModelSource: The source created from the keras_model
519+
520+
Raises:
521+
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
522+
"""
523+
TFLiteGCSModelSource._assert_tf_version_1_enabled()
524+
keras_file = 'keras_model.h5'
525+
tf.keras.models.save_model(keras_model, keras_file)
526+
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
527+
tflite_model = converter.convert()
528+
open('firebase_mlkit_model.tflite', 'wb').write(tflite_model)
529+
return TFLiteGCSModelSource.from_tflite_model_file(
530+
'firebase_mlkit_model.tflite', bucket_name, app)
531+
394532
@property
395533
def gcs_tflite_uri(self):
396534
return self._gcs_tflite_uri
@@ -399,10 +537,15 @@ def gcs_tflite_uri(self):
399537
def gcs_tflite_uri(self, gcs_tflite_uri):
400538
self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)
401539

402-
def as_dict(self):
403-
return {"gcsTfliteUri": self._gcs_tflite_uri}
540+
def _get_signed_gcs_tflite_uri(self):
541+
"""Signs the GCS uri, so the model file can be uploaded to Firebase ML Kit and verified."""
542+
return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app)
543+
544+
def as_dict(self, for_upload=False):
545+
if for_upload:
546+
return {'gcsTfliteUri': self._get_signed_gcs_tflite_uri()}
404547

405-
#TODO(ifielker): implement from_saved_model etc.
548+
return {'gcsTfliteUri': self._gcs_tflite_uri}
406549

407550

408551
class ListModelsPage(object):
@@ -671,13 +814,13 @@ def create_model(self, model):
671814
_validate_model(model)
672815
try:
673816
return self.handle_operation(
674-
self._client.body('post', url='models', json=model.as_dict()))
817+
self._client.body('post', url='models', json=model.as_dict(for_upload=True)))
675818
except requests.exceptions.RequestException as error:
676819
raise _utils.handle_platform_error_from_requests(error)
677820

678821
def update_model(self, model, update_mask=None):
679822
_validate_model(model, update_mask)
680-
data = {'model': model.as_dict()}
823+
data = {'model': model.as_dict(for_upload=True)}
681824
if update_mask is not None:
682825
data['updateMask'] = update_mask
683826
try:

tests/test_mlkit.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@
103103
}
104104
}
105105

106-
GCS_TFLITE_URI = 'gs://my_bucket/mymodel.tflite'
106+
GCS_BUCKET_NAME = 'my_bucket'
107+
GCS_BLOB_NAME = 'mymodel.tflite'
108+
GCS_TFLITE_URI = 'gs://{0}/{1}'.format(GCS_BUCKET_NAME, GCS_BLOB_NAME)
107109
GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI}
108110
GCS_TFLITE_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI)
109111
TFLITE_FORMAT_JSON = {
@@ -112,6 +114,10 @@
112114
}
113115
TFLITE_FORMAT = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON)
114116

117+
GCS_TFLITE_SIGNED_URI_PATTERN = (
118+
'https://storage.googleapis.com/{0}/{1}?X-Goog-Algorithm=GOOG4-RSA-SHA256&foo')
119+
GCS_TFLITE_SIGNED_URI = GCS_TFLITE_SIGNED_URI_PATTERN.format(GCS_BUCKET_NAME, GCS_BLOB_NAME)
120+
115121
GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite'
116122
GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2}
117123
GCS_TFLITE_MODEL_SOURCE_2 = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI_2)
@@ -325,6 +331,18 @@ def instrument_mlkit_service(status=200, payload=None, operations=False, app=Non
325331
session_url, adapter(payload, status, recorder))
326332
return recorder
327333

334+
class _TestStorageClient(object):
335+
@staticmethod
336+
def upload(bucket_name, model_file_name, app):
337+
del app # unused variable
338+
blob_name = mlkit._CloudStorageClient.BLOB_NAME.format(model_file_name)
339+
return mlkit._CloudStorageClient.GCS_URI.format(bucket_name, blob_name)
340+
341+
@staticmethod
342+
def sign_uri(gcs_tflite_uri, app):
343+
del app # unused variable
344+
bucket_name, blob_name = mlkit._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri)
345+
return GCS_TFLITE_SIGNED_URI_PATTERN.format(bucket_name, blob_name)
328346

329347
class TestModel(object):
330348
"""Tests mlkit.Model class."""
@@ -333,6 +351,7 @@ def setup_class(cls):
333351
cred = testutils.MockCredential()
334352
firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
335353
mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test
354+
mlkit.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient()
336355

337356
@classmethod
338357
def teardown_class(cls):
@@ -404,6 +423,13 @@ def test_model_format_source_creation(self):
404423
}
405424
}
406425

426+
def test_source_creation_from_tflite_file(self):
427+
model_source = mlkit.TFLiteGCSModelSource.from_tflite_model_file(
428+
"my_model.tflite", "my_bucket")
429+
assert model_source.as_dict() == {
430+
'gcsTfliteUri': 'gs://my_bucket/Firebase/MLKit/Models/my_model.tflite'
431+
}
432+
407433
def test_model_source_setters(self):
408434
model_source = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI)
409435
model_source.gcs_tflite_uri = GCS_TFLITE_URI_2
@@ -420,6 +446,27 @@ def test_model_format_setters(self):
420446
}
421447
}
422448

449+
def test_model_as_dict_for_upload(self):
450+
model_source = mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
451+
model_format = mlkit.TFLiteFormat(model_source=model_source)
452+
model = mlkit.Model(display_name=DISPLAY_NAME_1, model_format=model_format)
453+
assert model.as_dict(for_upload=True) == {
454+
'displayName': DISPLAY_NAME_1,
455+
'tfliteModel': {
456+
'gcsTfliteUri': GCS_TFLITE_SIGNED_URI
457+
}
458+
}
459+
460+
@pytest.mark.parametrize('helper_func', [
461+
mlkit.TFLiteGCSModelSource.from_keras_model,
462+
mlkit.TFLiteGCSModelSource.from_saved_model
463+
])
464+
def test_tf_not_enabled(self, helper_func):
465+
mlkit._TF_ENABLED = False # for reliability
466+
with pytest.raises(ImportError) as excinfo:
467+
helper_func(None)
468+
check_error(excinfo, ImportError)
469+
423470
@pytest.mark.parametrize('display_name, exc_type', [
424471
('', ValueError),
425472
('&_*#@:/?', ValueError),
@@ -803,6 +850,7 @@ def test_rpc_error(self, publish_function):
803850
)
804851
assert len(create_recorder) == 1
805852

853+
806854
class TestGetModel(object):
807855
"""Tests mlkit.get_model."""
808856
@classmethod

0 commit comments

Comments
 (0)