Skip to content

Commit 74d11d8

Browse files
authored
Adding support for TensorFlow 2.x (#372)
Adding support for TensorFlow 2.x and moving some payloads into parameters
1 parent 8939ed6 commit 74d11d8

File tree

2 files changed

+56
-31
lines changed

2 files changed

+56
-31
lines changed

firebase_admin/mlkit.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import six
2828

2929

30+
from six.moves import urllib
3031
from firebase_admin import _http_client
3132
from firebase_admin import _utils
3233
from firebase_admin import exceptions
@@ -200,6 +201,7 @@ def from_dict(cls, data, app=None):
200201
data_copy = dict(data)
201202
tflite_format = None
202203
tflite_format_data = data_copy.pop('tfliteModel', None)
204+
data_copy.pop('@type', None) # Returned by Operations. (Not needed)
203205
if tflite_format_data:
204206
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
205207
model = Model(model_format=tflite_format)
@@ -495,12 +497,31 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None):
495497
return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app)
496498

497499
@staticmethod
498-
def _assert_tf_version_1_enabled():
500+
def _assert_tf_enabled():
499501
if not _TF_ENABLED:
500502
raise ImportError('Failed to import the tensorflow library for Python. Make sure '
501503
'to install the tensorflow module.')
502-
if not tf.VERSION.startswith('1.'):
503-
raise ImportError('Expected tensorflow version 1.x, but found {0}'.format(tf.VERSION))
504+
if not tf.version.VERSION.startswith('1.') and not tf.version.VERSION.startswith('2.'):
505+
raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}'
506+
.format(tf.version.VERSION))
507+
508+
@staticmethod
509+
def _tf_convert_from_saved_model(saved_model_dir):
510+
# Same for both v1.x and v2.x
511+
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
512+
return converter.convert()
513+
514+
@staticmethod
515+
def _tf_convert_from_keras_model(keras_model):
516+
# Version 1.x conversion function takes a model file. Version 2.x takes the model itself.
517+
if tf.version.VERSION.startswith('1.'):
518+
keras_file = 'firebase_keras_model.h5'
519+
tf.keras.models.save_model(keras_model, keras_file)
520+
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
521+
return converter.convert()
522+
else:
523+
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
524+
return converter.convert()
504525

505526
@classmethod
506527
def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
@@ -518,9 +539,8 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
518539
Raises:
519540
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
520541
"""
521-
TFLiteGCSModelSource._assert_tf_version_1_enabled()
522-
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
523-
tflite_model = converter.convert()
542+
TFLiteGCSModelSource._assert_tf_enabled()
543+
tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir)
524544
open('firebase_mlkit_model.tflite', 'wb').write(tflite_model)
525545
return TFLiteGCSModelSource.from_tflite_model_file(
526546
'firebase_mlkit_model.tflite', bucket_name, app)
@@ -541,11 +561,8 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None):
541561
Raises:
542562
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
543563
"""
544-
TFLiteGCSModelSource._assert_tf_version_1_enabled()
545-
keras_file = 'keras_model.h5'
546-
tf.keras.models.save_model(keras_model, keras_file)
547-
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
548-
tflite_model = converter.convert()
564+
TFLiteGCSModelSource._assert_tf_enabled()
565+
tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model)
549566
open('firebase_mlkit_model.tflite', 'wb').write(tflite_model)
550567
return TFLiteGCSModelSource.from_tflite_model_file(
551568
'firebase_mlkit_model.tflite', bucket_name, app)
@@ -852,12 +869,12 @@ def create_model(self, model):
852869

853870
def update_model(self, model, update_mask=None):
854871
_validate_model(model, update_mask)
855-
data = {'model': model.as_dict(for_upload=True)}
872+
path = 'models/{0}'.format(model.model_id)
856873
if update_mask is not None:
857-
data['updateMask'] = update_mask
874+
path = path + '?updateMask={0}'.format(update_mask)
858875
try:
859876
return self.handle_operation(
860-
self._client.body('patch', url='models/{0}'.format(model.model_id), json=data))
877+
self._client.body('patch', url=path, json=model.as_dict(for_upload=True)))
861878
except requests.exceptions.RequestException as error:
862879
raise _utils.handle_platform_error_from_requests(error)
863880

@@ -884,15 +901,20 @@ def list_models(self, list_filter, page_size, page_token):
884901
_validate_list_filter(list_filter)
885902
_validate_page_size(page_size)
886903
_validate_page_token(page_token)
887-
payload = {}
904+
params = {}
888905
if list_filter:
889-
payload['list_filter'] = list_filter
906+
params['filter'] = list_filter
890907
if page_size:
891-
payload['page_size'] = page_size
908+
params['page_size'] = page_size
892909
if page_token:
893-
payload['page_token'] = page_token
910+
params['page_token'] = page_token
911+
path = 'models'
912+
if params:
913+
# pylint: disable=too-many-function-args
914+
param_str = urllib.parse.urlencode(sorted(params.items()), True)
915+
path = path + '?' + param_str
894916
try:
895-
return self._client.body('get', url='models', json=payload)
917+
return self._client.body('get', url=path)
896918
except requests.exceptions.RequestException as error:
897919
raise _utils.handle_platform_error_from_requests(error)
898920

tests/test_mlkit.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -763,7 +763,13 @@ def teardown_class(cls):
763763
testutils.cleanup_apps()
764764

765765
@staticmethod
766-
def _url(project_id, model_id):
766+
def _update_url(project_id, model_id):
767+
update_url = 'projects/{0}/models/{1}?updateMask=state.published'.format(
768+
project_id, model_id)
769+
return BASE_URL + update_url
770+
771+
@staticmethod
772+
def _get_url(project_id, model_id):
767773
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)
768774

769775
@staticmethod
@@ -778,10 +784,9 @@ def test_immediate_done(self, publish_function, published):
778784
assert model == CREATED_UPDATED_MODEL_1
779785
assert len(recorder) == 1
780786
assert recorder[0].method == 'PATCH'
781-
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
787+
assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)
782788
body = json.loads(recorder[0].body.decode())
783-
assert body.get('model', {}).get('state', {}).get('published', None) is published
784-
assert body.get('updateMask', {}) == 'state.published'
789+
assert body.get('state', {}).get('published', None) is published
785790

786791
@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
787792
def test_returns_locked(self, publish_function):
@@ -794,9 +799,9 @@ def test_returns_locked(self, publish_function):
794799
assert model == expected_model
795800
assert len(recorder) == 2
796801
assert recorder[0].method == 'PATCH'
797-
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
802+
assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)
798803
assert recorder[1].method == 'GET'
799-
assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
804+
assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1)
800805

801806
@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
802807
def test_operation_error(self, publish_function):
@@ -973,12 +978,10 @@ def test_list_models_with_all_args(self):
973978
page_token=PAGE_TOKEN)
974979
assert len(recorder) == 1
975980
assert recorder[0].method == 'GET'
976-
assert recorder[0].url == TestListModels._url(PROJECT_ID)
977-
assert json.loads(recorder[0].body.decode()) == {
978-
'list_filter': 'display_name=displayName3',
979-
'page_size': 10,
980-
'page_token': PAGE_TOKEN
981-
}
981+
assert recorder[0].url == (
982+
TestListModels._url(PROJECT_ID) +
983+
'?filter=display_name%3DdisplayName3&page_size=10&page_token={0}'
984+
.format(PAGE_TOKEN))
982985
assert isinstance(models_page, mlkit.ListModelsPage)
983986
assert len(models_page.models) == 1
984987
assert models_page.models[0] == MODEL_3

0 commit comments

Comments
 (0)