From f0aedda2c37feedf46599fe6935c6d4cfd568a50 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 27 Nov 2019 23:45:23 -0500 Subject: [PATCH 1/3] Firebase ML Kit Modify Operation Handling to not require a name for Done Operations --- firebase_admin/mlkit.py | 40 +++++++++++++++++++-------------- tests/test_mlkit.py | 49 +++++++++++------------------------------ 2 files changed, 37 insertions(+), 52 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index a135d2d7f..72a34395e 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -810,28 +810,36 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds """ if not isinstance(operation, dict): raise TypeError('Operation must be a dictionary.') - op_name = operation.get('name') - _, model_id = _validate_and_parse_operation_name(op_name) - - current_attempt = 0 - start_time = datetime.datetime.now() - stop_time = (None if max_time_seconds is None else - start_time + datetime.timedelta(seconds=max_time_seconds)) - while wait_for_operation and not operation.get('done'): - # We just got this operation. Wait before getting another - # so we don't exceed the GetOperation maximum request rate. - self._exponential_backoff(current_attempt, stop_time) - operation = self.get_operation(op_name) - current_attempt += 1 if operation.get('done'): + # Operations which are immediately done don't have an operation name if operation.get('response'): return operation.get('response') elif operation.get('error'): raise _utils.handle_operation_error(operation.get('error')) - - # If the operation is not complete or timed out, return a (locked) model instead - return get_model(model_id).as_dict() + raise exceptions.UnknownError(message='Internal Error: Malformed Operation.') + else: + op_name = operation.get('name') + _, model_id = _validate_and_parse_operation_name(op_name) + current_attempt = 0 + start_time = datetime.datetime.now() + stop_time = (None if max_time_seconds is None else + start_time + datetime.timedelta(seconds=max_time_seconds)) + while wait_for_operation and not operation.get('done'): + # We just got this operation. Wait before getting another + # so we don't exceed the GetOperation maximum request rate. + self._exponential_backoff(current_attempt, stop_time) + operation = self.get_operation(op_name) + current_attempt += 1 + + if operation.get('done'): + if operation.get('response'): + return operation.get('response') + elif operation.get('error'): + raise _utils.handle_operation_error(operation.get('error')) + + # If the operation is not complete or timed out, return a (locked) model instead + return get_model(model_id).as_dict() def create_model(self, model): diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 26afdfa99..fbe31aec4 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -158,23 +158,21 @@ } OPERATION_DONE_MODEL_JSON_1 = { - 'name': OPERATION_NAME_1, 'done': True, 'response': CREATED_UPDATED_MODEL_JSON_1 } OPERATION_MALFORMED_JSON_1 = { - 'name': OPERATION_NAME_1, 'done': True, # if done is true then either response or error should be populated } OPERATION_MISSING_NAME = { + # Name is required if the operation is not done. 'done': False } OPERATION_ERROR_CODE = 400 OPERATION_ERROR_MSG = "Invalid argument" OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT' OPERATION_ERROR_JSON_1 = { - 'name': OPERATION_NAME_1, 'done': True, 'error': { 'code': OPERATION_ERROR_CODE, @@ -609,17 +607,10 @@ def test_operation_error(self): check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): - recorder = instrument_mlkit_service( - status=[200, 200], - payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) - model = mlkit.create_model(MODEL_1) - assert model == expected_model - assert len(recorder) == 2 - assert recorder[0].method == 'POST' - assert recorder[0].url == TestCreateModel._url(PROJECT_ID) - assert recorder[1].method == 'GET' - assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) + instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + mlkit.create_model(MODEL_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') def test_rpc_error_create(self): create_recorder = instrument_mlkit_service( @@ -708,17 +699,10 @@ def test_operation_error(self): check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): - recorder = instrument_mlkit_service( - status=[200, 200], - payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) - model = mlkit.update_model(MODEL_1) - assert model == expected_model - assert len(recorder) == 2 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) - assert recorder[1].method == 'GET' - assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + mlkit.update_model(MODEL_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') def test_rpc_error(self): create_recorder = instrument_mlkit_service( @@ -824,17 +808,10 @@ def test_operation_error(self, publish_function): @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_malformed_operation(self, publish_function): - recorder = instrument_mlkit_service( - status=[200, 200], - payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) - model = publish_function(MODEL_ID_1) - assert model == expected_model - assert len(recorder) == 2 - assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) - assert recorder[1].method == 'GET' - assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(Exception) as excinfo: + publish_function(MODEL_ID_1) + check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_rpc_error(self, publish_function): From 8939ed60c991ce27c2d38bd57e1718e0a580b2e6 Mon Sep 17 00:00:00 2001 From: ifielker Date: Thu, 28 Nov 2019 00:00:25 -0500 Subject: [PATCH 2/3] fixed lint --- firebase_admin/mlkit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 72a34395e..da8bda3c8 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -824,7 +824,7 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds current_attempt = 0 start_time = datetime.datetime.now() stop_time = (None if max_time_seconds is None else - start_time + datetime.timedelta(seconds=max_time_seconds)) + start_time + datetime.timedelta(seconds=max_time_seconds)) while wait_for_operation and not operation.get('done'): # We just got this operation. Wait before getting another # so we don't exceed the GetOperation maximum request rate. From 74d11d81946175ba41091c0141cbdea0b6994c78 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 10 Dec 2019 12:44:53 -0500 Subject: [PATCH 3/3] Adding support for TensorFlow 2.x (#372) Adding support for TensorFlow 2.x and moving some payloads into parameters --- firebase_admin/mlkit.py | 60 ++++++++++++++++++++++++++++------------- tests/test_mlkit.py | 27 ++++++++++--------- 2 files changed, 56 insertions(+), 31 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index da8bda3c8..bb277abf9 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -27,6 +27,7 @@ import six +from six.moves import urllib from firebase_admin import _http_client from firebase_admin import _utils from firebase_admin import exceptions @@ -200,6 +201,7 @@ def from_dict(cls, data, app=None): data_copy = dict(data) tflite_format = None tflite_format_data = data_copy.pop('tfliteModel', None) + data_copy.pop('@type', None) # Returned by Operations. (Not needed) if tflite_format_data: tflite_format = TFLiteFormat.from_dict(tflite_format_data) model = Model(model_format=tflite_format) @@ -495,12 +497,31 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app) @staticmethod - def _assert_tf_version_1_enabled(): + def _assert_tf_enabled(): if not _TF_ENABLED: raise ImportError('Failed to import the tensorflow library for Python. Make sure ' 'to install the tensorflow module.') - if not tf.VERSION.startswith('1.'): - raise ImportError('Expected tensorflow version 1.x, but found {0}'.format(tf.VERSION)) + if not tf.version.VERSION.startswith('1.') and not tf.version.VERSION.startswith('2.'): + raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}' + .format(tf.version.VERSION)) + + @staticmethod + def _tf_convert_from_saved_model(saved_model_dir): + # Same for both v1.x and v2.x + converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) + return converter.convert() + + @staticmethod + def _tf_convert_from_keras_model(keras_model): + # Version 1.x conversion function takes a model file. Version 2.x takes the model itself. + if tf.version.VERSION.startswith('1.'): + keras_file = 'firebase_keras_model.h5' + tf.keras.models.save_model(keras_model, keras_file) + converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) + return converter.convert() + else: + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + return converter.convert() @classmethod def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): @@ -518,9 +539,8 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): Raises: ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. """ - TFLiteGCSModelSource._assert_tf_version_1_enabled() - converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) - tflite_model = converter.convert() + TFLiteGCSModelSource._assert_tf_enabled() + tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir) open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file( 'firebase_mlkit_model.tflite', bucket_name, app) @@ -541,11 +561,8 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): Raises: ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. """ - TFLiteGCSModelSource._assert_tf_version_1_enabled() - keras_file = 'keras_model.h5' - tf.keras.models.save_model(keras_model, keras_file) - converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) - tflite_model = converter.convert() + TFLiteGCSModelSource._assert_tf_enabled() + tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model) open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file( 'firebase_mlkit_model.tflite', bucket_name, app) @@ -852,12 +869,12 @@ def create_model(self, model): def update_model(self, model, update_mask=None): _validate_model(model, update_mask) - data = {'model': model.as_dict(for_upload=True)} + path = 'models/{0}'.format(model.model_id) if update_mask is not None: - data['updateMask'] = update_mask + path = path + '?updateMask={0}'.format(update_mask) try: return self.handle_operation( - self._client.body('patch', url='models/{0}'.format(model.model_id), json=data)) + self._client.body('patch', url=path, json=model.as_dict(for_upload=True))) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) @@ -884,15 +901,20 @@ def list_models(self, list_filter, page_size, page_token): _validate_list_filter(list_filter) _validate_page_size(page_size) _validate_page_token(page_token) - payload = {} + params = {} if list_filter: - payload['list_filter'] = list_filter + params['filter'] = list_filter if page_size: - payload['page_size'] = page_size + params['page_size'] = page_size if page_token: - payload['page_token'] = page_token + params['page_token'] = page_token + path = 'models' + if params: + # pylint: disable=too-many-function-args + param_str = urllib.parse.urlencode(sorted(params.items()), True) + path = path + '?' + param_str try: - return self._client.body('get', url='models', json=payload) + return self._client.body('get', url=path) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index fbe31aec4..dbe590673 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -763,7 +763,13 @@ def teardown_class(cls): testutils.cleanup_apps() @staticmethod - def _url(project_id, model_id): + def _update_url(project_id, model_id): + update_url = 'projects/{0}/models/{1}?updateMask=state.published'.format( + project_id, model_id) + return BASE_URL + update_url + + @staticmethod + def _get_url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) @staticmethod @@ -778,10 +784,9 @@ def test_immediate_done(self, publish_function, published): assert model == CREATED_UPDATED_MODEL_1 assert len(recorder) == 1 assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) body = json.loads(recorder[0].body.decode()) - assert body.get('model', {}).get('state', {}).get('published', None) is published - assert body.get('updateMask', {}) == 'state.published' + assert body.get('state', {}).get('published', None) is published @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_returns_locked(self, publish_function): @@ -794,9 +799,9 @@ def test_returns_locked(self, publish_function): assert model == expected_model assert len(recorder) == 2 assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) assert recorder[1].method == 'GET' - assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1) @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_operation_error(self, publish_function): @@ -973,12 +978,10 @@ def test_list_models_with_all_args(self): page_token=PAGE_TOKEN) assert len(recorder) == 1 assert recorder[0].method == 'GET' - assert recorder[0].url == TestListModels._url(PROJECT_ID) - assert json.loads(recorder[0].body.decode()) == { - 'list_filter': 'display_name=displayName3', - 'page_size': 10, - 'page_token': PAGE_TOKEN - } + assert recorder[0].url == ( + TestListModels._url(PROJECT_ID) + + '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' + .format(PAGE_TOKEN)) assert isinstance(models_page, mlkit.ListModelsPage) assert len(models_page.models) == 1 assert models_page.models[0] == MODEL_3