From f0aedda2c37feedf46599fe6935c6d4cfd568a50 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 27 Nov 2019 23:45:23 -0500 Subject: [PATCH 01/15] Firebase ML Kit Modify Operation Handling to not require a name for Done Operations --- firebase_admin/mlkit.py | 40 +++++++++++++++++++-------------- tests/test_mlkit.py | 49 +++++++++++------------------------------ 2 files changed, 37 insertions(+), 52 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index a135d2d7f..72a34395e 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -810,28 +810,36 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds """ if not isinstance(operation, dict): raise TypeError('Operation must be a dictionary.') - op_name = operation.get('name') - _, model_id = _validate_and_parse_operation_name(op_name) - - current_attempt = 0 - start_time = datetime.datetime.now() - stop_time = (None if max_time_seconds is None else - start_time + datetime.timedelta(seconds=max_time_seconds)) - while wait_for_operation and not operation.get('done'): - # We just got this operation. Wait before getting another - # so we don't exceed the GetOperation maximum request rate. - self._exponential_backoff(current_attempt, stop_time) - operation = self.get_operation(op_name) - current_attempt += 1 if operation.get('done'): + # Operations which are immediately done don't have an operation name if operation.get('response'): return operation.get('response') elif operation.get('error'): raise _utils.handle_operation_error(operation.get('error')) - - # If the operation is not complete or timed out, return a (locked) model instead - return get_model(model_id).as_dict() + raise exceptions.UnknownError(message='Internal Error: Malformed Operation.') + else: + op_name = operation.get('name') + _, model_id = _validate_and_parse_operation_name(op_name) + current_attempt = 0 + start_time = datetime.datetime.now() + stop_time = (None if max_time_seconds is None else + start_time + datetime.timedelta(seconds=max_time_seconds)) + while wait_for_operation and not operation.get('done'): + # We just got this operation. Wait before getting another + # so we don't exceed the GetOperation maximum request rate. + self._exponential_backoff(current_attempt, stop_time) + operation = self.get_operation(op_name) + current_attempt += 1 + + if operation.get('done'): + if operation.get('response'): + return operation.get('response') + elif operation.get('error'): + raise _utils.handle_operation_error(operation.get('error')) + + # If the operation is not complete or timed out, return a (locked) model instead + return get_model(model_id).as_dict() def create_model(self, model): diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 26afdfa99..fbe31aec4 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -158,23 +158,21 @@ } OPERATION_DONE_MODEL_JSON_1 = { - 'name': OPERATION_NAME_1, 'done': True, 'response': CREATED_UPDATED_MODEL_JSON_1 } OPERATION_MALFORMED_JSON_1 = { - 'name': OPERATION_NAME_1, 'done': True, # if done is true then either response or error should be populated } OPERATION_MISSING_NAME = { + # Name is required if the operation is not done. 'done': False } OPERATION_ERROR_CODE = 400 OPERATION_ERROR_MSG = "Invalid argument" OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT' OPERATION_ERROR_JSON_1 = { - 'name': OPERATION_NAME_1, 'done': True, 'error': { 'code': OPERATION_ERROR_CODE, @@ -609,17 +607,10 @@ def test_operation_error(self): check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): - recorder = instrument_mlkit_service( - status=[200, 200], - payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) - model = mlkit.create_model(MODEL_1) - assert model == expected_model - assert len(recorder) == 2 - assert recorder[0].method == 'POST' - assert recorder[0].url == TestCreateModel._url(PROJECT_ID) - assert recorder[1].method == 'GET' - assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) + instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + mlkit.create_model(MODEL_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') def test_rpc_error_create(self): create_recorder = instrument_mlkit_service( @@ -708,17 +699,10 @@ def test_operation_error(self): check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): - recorder = instrument_mlkit_service( - status=[200, 200], - payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) - model = mlkit.update_model(MODEL_1) - assert model == expected_model - assert len(recorder) == 2 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) - assert recorder[1].method == 'GET' - assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + mlkit.update_model(MODEL_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') def test_rpc_error(self): create_recorder = instrument_mlkit_service( @@ -824,17 +808,10 @@ def test_operation_error(self, publish_function): @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_malformed_operation(self, publish_function): - recorder = instrument_mlkit_service( - status=[200, 200], - payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) - model = publish_function(MODEL_ID_1) - assert model == expected_model - assert len(recorder) == 2 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) - assert recorder[1].method == 'GET' - assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + publish_function(MODEL_ID_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_rpc_error(self, publish_function): From 8939ed60c991ce27c2d38bd57e1718e0a580b2e6 Mon Sep 17 00:00:00 2001 From: ifielker Date: Thu, 28 Nov 2019 00:00:25 -0500 Subject: [PATCH 02/15] fixed lint --- firebase_admin/mlkit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 72a34395e..da8bda3c8 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -824,7 +824,7 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds current_attempt = 0 start_time = datetime.datetime.now() stop_time = (None if max_time_seconds is None else - start_time + datetime.timedelta(seconds=max_time_seconds)) + start_time + datetime.timedelta(seconds=max_time_seconds)) while wait_for_operation and not operation.get('done'): # We just got this operation. Wait before getting another # so we don't exceed the GetOperation maximum request rate. From 7078cb6a61d1fc2d06a84803d117cc12b15bf56d Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 2 Dec 2019 14:29:12 -0500 Subject: [PATCH 03/15] Adding support for TensorFlow 2.x --- firebase_admin/mlkit.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index da8bda3c8..5d5c47a47 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -495,12 +495,30 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app) @staticmethod - def _assert_tf_version_1_enabled(): + def _assert_tf_enabled(): if not _TF_ENABLED: raise ImportError('Failed to import the tensorflow library for Python. Make sure ' 'to install the tensorflow module.') - if not tf.VERSION.startswith('1.'): - raise ImportError('Expected tensorflow version 1.x, but found {0}'.format(tf.VERSION)) + if not tf.version.VERSION.startswith('1.') and not tf.Version.startswith('2.'): + raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}' + .format(tf.version.VERSION)) + + @staticmethod + def _tf_convert_from_saved_model(saved_model_dir): + # Same for both v1.x and v2.x + converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) + return converter.convert() + + @staticmethod + def _tf_convert_from_keras_model(keras_model): + if tf.version.VERSION.startswith('1.'): + keras_file = 'firebase_keras_model.h5' + tf.keras.models.save_model(keras_model, keras_file) + converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) + return converter.convert() + else: + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + return converter.convert() @classmethod def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): @@ -518,9 +536,8 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): Raises: ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. """ - TFLiteGCSModelSource._assert_tf_version_1_enabled() - converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) - tflite_model = converter.convert() + TFLiteGCSModelSource._assert_tf_enabled() + tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir) open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file( 'firebase_mlkit_model.tflite', bucket_name, app) @@ -541,11 +558,8 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): Raises: ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. """ - TFLiteGCSModelSource._assert_tf_version_1_enabled() - keras_file = 'keras_model.h5' - tf.keras.models.save_model(keras_model, keras_file) - converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) - tflite_model = converter.convert() + TFLiteGCSModelSource._assert_tf_enabled() + tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model) open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file( 'firebase_mlkit_model.tflite', bucket_name, app) From d88a66b4cb920861200d6258dccd6b830796efd2 Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 2 Dec 2019 14:37:34 -0500 Subject: [PATCH 04/15] fix typo --- firebase_admin/mlkit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 5d5c47a47..580f29d4c 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -499,7 +499,7 @@ def _assert_tf_enabled(): if not _TF_ENABLED: raise ImportError('Failed to import the tensorflow library for Python. Make sure ' 'to install the tensorflow module.') - if not tf.version.VERSION.startswith('1.') and not tf.Version.startswith('2.'): + if not tf.version.VERSION.startswith('1.') and not tf.version.VERSION.startswith('2.'): raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}' .format(tf.version.VERSION)) From 8b5a6b34a01a58b16f800a5c6b116f0d96fd59cb Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 2 Dec 2019 14:44:01 -0500 Subject: [PATCH 05/15] remove extraneous @type from operations --- firebase_admin/mlkit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 580f29d4c..027786400 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -200,6 +200,7 @@ def from_dict(cls, data, app=None): data_copy = dict(data) tflite_format = None tflite_format_data = data_copy.pop('tfliteModel', None) + data_copy.pop('@type', None) # Returned by Operations. (Not needed) if tflite_format_data: tflite_format = TFLiteFormat.from_dict(tflite_format_data) model = Model(model_format=tflite_format) From 7ee369d9d573723f37af10c4a942b4d781110e93 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Dec 2019 14:29:20 -0500 Subject: [PATCH 06/15] send updateMask in query parameter --- firebase_admin/mlkit.py | 6 +++--- tests/test_mlkit.py | 16 ++++++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 027786400..b89fa5fdf 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -867,12 +867,12 @@ def create_model(self, model): def update_model(self, model, update_mask=None): _validate_model(model, update_mask) - data = {'model': model.as_dict(for_upload=True)} + path = 'models/{0}'.format(model.model_id) if update_mask is not None: - data['updateMask'] = update_mask + path = path + '?updateMask={0}'.format(update_mask) try: return self.handle_operation( - self._client.body('patch', url='models/{0}'.format(model.model_id), json=data)) + self._client.body('patch', url=path, json=model.as_dict(for_upload=True))) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index fbe31aec4..8b0816b0a 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -763,7 +763,12 @@ def teardown_class(cls): testutils.cleanup_apps() @staticmethod - def _url(project_id, model_id): + def _update_url(project_id, model_id): + update_url = 'projects/{0}/models/{1}?updateMask=state.published' + return BASE_URL + update_url.format(project_id, model_id) + + @staticmethod + def _get_url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) @staticmethod @@ -778,10 +783,9 @@ def test_immediate_done(self, publish_function, published): assert model == CREATED_UPDATED_MODEL_1 assert len(recorder) == 1 assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) body = json.loads(recorder[0].body.decode()) - assert body.get('model', {}).get('state', {}).get('published', None) is published - assert body.get('updateMask', {}) == 'state.published' + assert body.get('state', {}).get('published', None) is published @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_returns_locked(self, publish_function): @@ -794,9 +798,9 @@ def test_returns_locked(self, publish_function): assert model == expected_model assert len(recorder) == 2 assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) assert recorder[1].method == 'GET' - assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1) @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_operation_error(self, publish_function): From d1a993335fe8eceb22612487597b389176b69df5 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Dec 2019 15:22:21 -0500 Subject: [PATCH 07/15] send list filters etc in query parameters --- firebase_admin/mlkit.py | 13 ++++++++----- tests/test_mlkit.py | 10 ++++------ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index b89fa5fdf..6e204029d 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -899,15 +899,18 @@ def list_models(self, list_filter, page_size, page_token): _validate_list_filter(list_filter) _validate_page_size(page_size) _validate_page_token(page_token) - payload = {} + path = 'models' + joiner = '?' if list_filter: - payload['list_filter'] = list_filter + path = path + joiner + 'listFilter=\'{0}\''.format(list_filter) + joiner = '&' if page_size: - payload['page_size'] = page_size + path = path + joiner + 'pageSize={0}'.format(page_size) + joiner = '&' if page_token: - payload['page_token'] = page_token + path = path + joiner + 'pageToken={0}'.format(page_token) try: - return self._client.body('get', url='models', json=payload) + return self._client.body('get', url=path) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 8b0816b0a..577e5b420 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -977,12 +977,10 @@ def test_list_models_with_all_args(self): page_token=PAGE_TOKEN) assert len(recorder) == 1 assert recorder[0].method == 'GET' - assert recorder[0].url == TestListModels._url(PROJECT_ID) - assert json.loads(recorder[0].body.decode()) == { - 'list_filter': 'display_name=displayName3', - 'page_size': 10, - 'page_token': PAGE_TOKEN - } + assert recorder[0].url == ( + TestListModels._url(PROJECT_ID) + + '?listFilter=\'display_name=displayName3\'&pageSize=10&pageToken={0}' + .format(PAGE_TOKEN)) assert isinstance(models_page, mlkit.ListModelsPage) assert len(models_page.models) == 1 assert models_page.models[0] == MODEL_3 From 68ebe7732f1f254ff9356c963cbf1687e5851a06 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Dec 2019 17:16:14 -0500 Subject: [PATCH 08/15] fix typo --- firebase_admin/mlkit.py | 6 +++--- tests/test_mlkit.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 6e204029d..624a4000d 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -902,13 +902,13 @@ def list_models(self, list_filter, page_size, page_token): path = 'models' joiner = '?' if list_filter: - path = path + joiner + 'listFilter=\'{0}\''.format(list_filter) + path = path + joiner + 'list_filter=\'{0}\''.format(list_filter) joiner = '&' if page_size: - path = path + joiner + 'pageSize={0}'.format(page_size) + path = path + joiner + 'page_size={0}'.format(page_size) joiner = '&' if page_token: - path = path + joiner + 'pageToken={0}'.format(page_token) + path = path + joiner + 'page_token={0}'.format(page_token) try: return self._client.body('get', url=path) except requests.exceptions.RequestException as error: diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 577e5b420..8926be6dc 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -979,7 +979,7 @@ def test_list_models_with_all_args(self): assert recorder[0].method == 'GET' assert recorder[0].url == ( TestListModels._url(PROJECT_ID) + - '?listFilter=\'display_name=displayName3\'&pageSize=10&pageToken={0}' + '?list_filter=\'display_name=displayName3\'&page_size=10&page_token={0}' .format(PAGE_TOKEN)) assert isinstance(models_page, mlkit.ListModelsPage) assert len(models_page.models) == 1 From 5633b9f32e9a198519107d095f5df1360a4b7838 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Dec 2019 17:21:26 -0500 Subject: [PATCH 09/15] fix typo --- firebase_admin/mlkit.py | 2 +- tests/test_mlkit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 624a4000d..17d15989f 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -902,7 +902,7 @@ def list_models(self, list_filter, page_size, page_token): path = 'models' joiner = '?' if list_filter: - path = path + joiner + 'list_filter=\'{0}\''.format(list_filter) + path = path + joiner + 'filter=\'{0}\''.format(list_filter) joiner = '&' if page_size: path = path + joiner + 'page_size={0}'.format(page_size) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 8926be6dc..5f12dcb7e 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -979,7 +979,7 @@ def test_list_models_with_all_args(self): assert recorder[0].method == 'GET' assert recorder[0].url == ( TestListModels._url(PROJECT_ID) + - '?list_filter=\'display_name=displayName3\'&page_size=10&page_token={0}' + '?filter=\'display_name=displayName3\'&page_size=10&page_token={0}' .format(PAGE_TOKEN)) assert isinstance(models_page, mlkit.ListModelsPage) assert len(models_page.models) == 1 From d715628543d73000c1f7f803cdc418c653a52840 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Dec 2019 17:49:10 -0500 Subject: [PATCH 10/15] fix typo --- firebase_admin/mlkit.py | 2 +- tests/test_mlkit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 17d15989f..050a1ffba 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -902,7 +902,7 @@ def list_models(self, list_filter, page_size, page_token): path = 'models' joiner = '?' if list_filter: - path = path + joiner + 'filter=\'{0}\''.format(list_filter) + path = path + joiner + 'filter={0}'.format(list_filter) joiner = '&' if page_size: path = path + joiner + 'page_size={0}'.format(page_size) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 5f12dcb7e..0daec946b 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -979,7 +979,7 @@ def test_list_models_with_all_args(self): assert recorder[0].method == 'GET' assert recorder[0].url == ( TestListModels._url(PROJECT_ID) + - '?filter=\'display_name=displayName3\'&page_size=10&page_token={0}' + '?filter=display_name=displayName3&page_size=10&page_token={0}' .format(PAGE_TOKEN)) assert isinstance(models_page, mlkit.ListModelsPage) assert len(models_page.models) == 1 From b63a4feca17dd52b0c808c7cbcdcbabbc328219d Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Dec 2019 18:15:35 -0500 Subject: [PATCH 11/15] urlEncode filter string --- firebase_admin/mlkit.py | 16 +++++++++------- tests/test_mlkit.py | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 050a1ffba..b75dd2233 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -27,6 +27,7 @@ import six +from six.moves import urllib from firebase_admin import _http_client from firebase_admin import _utils from firebase_admin import exceptions @@ -899,16 +900,17 @@ def list_models(self, list_filter, page_size, page_token): _validate_list_filter(list_filter) _validate_page_size(page_size) _validate_page_token(page_token) - path = 'models' - joiner = '?' + params = {} if list_filter: - path = path + joiner + 'filter={0}'.format(list_filter) - joiner = '&' + params['filter'] = list_filter if page_size: - path = path + joiner + 'page_size={0}'.format(page_size) - joiner = '&' + params['page_size'] = page_size if page_token: - path = path + joiner + 'page_token={0}'.format(page_token) + params['page_token'] = page_token + path = 'models' + if params != {}: + param_str = urllib.parse.urlencode(sorted(params.items()), True) + path = path + '?' + param_str try: return self._client.body('get', url=path) except requests.exceptions.RequestException as error: diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 0daec946b..3198eff34 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -979,7 +979,7 @@ def test_list_models_with_all_args(self): assert recorder[0].method == 'GET' assert recorder[0].url == ( TestListModels._url(PROJECT_ID) + - '?filter=display_name=displayName3&page_size=10&page_token={0}' + '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' .format(PAGE_TOKEN)) assert isinstance(models_page, mlkit.ListModelsPage) assert len(models_page.models) == 1 From 71e6705464b5e5645ed8cdfd85a575cc2647d228 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Dec 2019 19:49:09 -0500 Subject: [PATCH 12/15] adding optional file names to conversion functions (start) --- firebase_admin/mlkit.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index b75dd2233..f810a55da 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -51,6 +51,7 @@ _MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') _DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') _TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_MODEL_FILE_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_]{1,60}.tflite$') _GCS_TFLITE_URI_PATTERN = re.compile( r'^gs://(?P[a-z0-9_.-]{3,63})/(?P.+)$') _RESOURCE_NAME_PATTERN = re.compile( @@ -522,12 +523,19 @@ def _tf_convert_from_keras_model(keras_model): converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) return converter.convert() + @staticmethod + def _validate_model_file_name(file_name): + if not _MODEL_FILE_NAME_PATTERN.match(file_name): + raise ValueError('Model file name format is invalid.') + return file_name + @classmethod - def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): + def from_saved_model(cls, saved_model_dir, model_file_name='firebase_mlkit_model.tflite', bucket_name=None, app=None): """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS. Args: saved_model_dir: The saved model directory. + model_file_name: The name that the tflite model will be saved as in Cloud Storage. bucket_name: The name of an existing bucket. None to use the default bucket configured in the app. app: Optional. A Firebase app instance (or None to use the default app) @@ -538,11 +546,11 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): Raises: ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. """ + file_name = TFLiteGCSModelSource._validate_model_file_name(model_file_name) TFLiteGCSModelSource._assert_tf_enabled() tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir) - open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) - return TFLiteGCSModelSource.from_tflite_model_file( - 'firebase_mlkit_model.tflite', bucket_name, app) + open(file_name, 'wb').write(tflite_model) + return TFLiteGCSModelSource.from_tflite_model_file(file_name, bucket_name, app) @classmethod def from_keras_model(cls, keras_model, bucket_name=None, app=None): @@ -909,6 +917,7 @@ def list_models(self, list_filter, page_size, page_token): params['page_token'] = page_token path = 'models' if params != {}: + # pylint: disable=too-many-function-args param_str = urllib.parse.urlencode(sorted(params.items()), True) path = path + '?' + param_str try: From d24e7420d63a86a42d986efed4ff67fb010561aa Mon Sep 17 00:00:00 2001 From: ifielker Date: Fri, 13 Dec 2019 15:32:17 -0500 Subject: [PATCH 13/15] adding File naming capability for ModelSource --- firebase_admin/ml.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 540be1520..61dfedd80 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -51,7 +51,6 @@ _MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') _DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') _TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') -_MODEL_FILE_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_]{1,60}.tflite$') _GCS_TFLITE_URI_PATTERN = re.compile( r'^gs://(?P[a-z0-9_.-]{3,63})/(?P.+)$') _RESOURCE_NAME_PATTERN = re.compile( @@ -524,14 +523,8 @@ def _tf_convert_from_keras_model(keras_model): converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) return converter.convert() - @staticmethod - def _validate_model_file_name(file_name): - if not _MODEL_FILE_NAME_PATTERN.match(file_name): - raise ValueError('Model file name format is invalid.') - return file_name - @classmethod - def from_saved_model(cls, saved_model_dir, model_file_name='firebase_mlkit_model.tflite', bucket_name=None, app=None): + def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite', bucket_name=None, app=None): """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS. Args: @@ -547,18 +540,18 @@ def from_saved_model(cls, saved_model_dir, model_file_name='firebase_mlkit_model Raises: ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. """ - file_name = TFLiteGCSModelSource._validate_model_file_name(model_file_name) TFLiteGCSModelSource._assert_tf_enabled() tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir) - open(file_name, 'wb').write(tflite_model) - return TFLiteGCSModelSource.from_tflite_model_file(file_name, bucket_name, app) + open(model_file_name, 'wb').write(tflite_model) + return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @classmethod - def from_keras_model(cls, keras_model, bucket_name=None, app=None): + def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite', bucket_name=None, app=None): """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS. Args: keras_model: A tf.keras model. + model_file_name: The name that the tflite model will be saved as in Cloud Storage. bucket_name: The name of an existing bucket. None to use the default bucket configured in the app. app: Optional. A Firebase app instance (or None to use the default app) @@ -571,9 +564,8 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): """ TFLiteGCSModelSource._assert_tf_enabled() tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model) - open('firebase_ml_model.tflite', 'wb').write(tflite_model) - return TFLiteGCSModelSource.from_tflite_model_file( - 'firebase_ml_model.tflite', bucket_name, app) + open(model_file_name, 'wb').write(tflite_model) + return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @property def gcs_tflite_uri(self): From 0233bfac96bb5b329b87ef8dd60d5e394e08a164 Mon Sep 17 00:00:00 2001 From: ifielker Date: Fri, 13 Dec 2019 15:42:45 -0500 Subject: [PATCH 14/15] fixed lint --- firebase_admin/ml.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 61dfedd80..f49c43f10 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -524,7 +524,8 @@ def _tf_convert_from_keras_model(keras_model): return converter.convert() @classmethod - def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite', bucket_name=None, app=None): + def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite', + bucket_name=None, app=None): """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS. Args: @@ -546,7 +547,8 @@ def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tf return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @classmethod - def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite', bucket_name=None, app=None): + def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite', + bucket_name=None, app=None): """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS. Args: From 7230667cae575b4cf23d9e1c37008d3063282b6e Mon Sep 17 00:00:00 2001 From: ifielker Date: Fri, 13 Dec 2019 16:14:25 -0500 Subject: [PATCH 15/15] fixed file descriptor leak --- firebase_admin/ml.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index f49c43f10..c6720f081 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -543,7 +543,8 @@ def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tf """ TFLiteGCSModelSource._assert_tf_enabled() tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir) - open(model_file_name, 'wb').write(tflite_model) + with open(model_file_name, 'wb') as model_file: + model_file.write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @classmethod @@ -566,7 +567,8 @@ def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite """ TFLiteGCSModelSource._assert_tf_enabled() tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model) - open(model_file_name, 'wb').write(tflite_model) + with open(model_file_name, 'wb') as model_file: + model_file.write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @property