Skip to content

Modify Operation Handling to not require a name for Done Operations #371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 65 additions & 35 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import six


from six.moves import urllib
from firebase_admin import _http_client
from firebase_admin import _utils
from firebase_admin import exceptions
Expand Down Expand Up @@ -200,6 +201,7 @@ def from_dict(cls, data, app=None):
data_copy = dict(data)
tflite_format = None
tflite_format_data = data_copy.pop('tfliteModel', None)
data_copy.pop('@type', None) # Returned by Operations. (Not needed)
if tflite_format_data:
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
model = Model(model_format=tflite_format)
Expand Down Expand Up @@ -495,12 +497,31 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None):
return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app)

@staticmethod
def _assert_tf_version_1_enabled():
def _assert_tf_enabled():
if not _TF_ENABLED:
raise ImportError('Failed to import the tensorflow library for Python. Make sure '
'to install the tensorflow module.')
if not tf.VERSION.startswith('1.'):
raise ImportError('Expected tensorflow version 1.x, but found {0}'.format(tf.VERSION))
if not tf.version.VERSION.startswith('1.') and not tf.version.VERSION.startswith('2.'):
raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}'
.format(tf.version.VERSION))

@staticmethod
def _tf_convert_from_saved_model(saved_model_dir):
    """Converts the SavedModel at ``saved_model_dir`` and returns the converter output."""
    # The SavedModel entry point is the same for TensorFlow 1.x and 2.x,
    # so no version check is needed here.
    return tf.lite.TFLiteConverter.from_saved_model(saved_model_dir).convert()

@staticmethod
def _tf_convert_from_keras_model(keras_model):
    """Converts ``keras_model`` with the TFLite converter and returns the converter output.

    The conversion entry point depends on the installed TensorFlow version:
    the 1.x converter takes a saved model file, the 2.x converter takes the
    model object itself.
    """
    if not tf.version.VERSION.startswith('1.'):
        # 2.x: hand the in-memory model straight to the converter.
        return tf.lite.TFLiteConverter.from_keras_model(keras_model).convert()

    # 1.x: the converter only reads from a file, so save the model first
    # (relative path, i.e. written to the current working directory).
    h5_path = 'firebase_keras_model.h5'
    tf.keras.models.save_model(keras_model, h5_path)
    return tf.lite.TFLiteConverter.from_keras_model_file(h5_path).convert()

@classmethod
def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
Expand All @@ -518,9 +539,8 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
Raises:
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
"""
TFLiteGCSModelSource._assert_tf_version_1_enabled()
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
TFLiteGCSModelSource._assert_tf_enabled()
tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir)
open('firebase_mlkit_model.tflite', 'wb').write(tflite_model)
return TFLiteGCSModelSource.from_tflite_model_file(
'firebase_mlkit_model.tflite', bucket_name, app)
Expand All @@ -541,11 +561,8 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None):
Raises:
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
"""
TFLiteGCSModelSource._assert_tf_version_1_enabled()
keras_file = 'keras_model.h5'
tf.keras.models.save_model(keras_model, keras_file)
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
TFLiteGCSModelSource._assert_tf_enabled()
tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model)
open('firebase_mlkit_model.tflite', 'wb').write(tflite_model)
return TFLiteGCSModelSource.from_tflite_model_file(
'firebase_mlkit_model.tflite', bucket_name, app)
Expand Down Expand Up @@ -810,28 +827,36 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
"""
if not isinstance(operation, dict):
raise TypeError('Operation must be a dictionary.')
op_name = operation.get('name')
_, model_id = _validate_and_parse_operation_name(op_name)

current_attempt = 0
start_time = datetime.datetime.now()
stop_time = (None if max_time_seconds is None else
start_time + datetime.timedelta(seconds=max_time_seconds))
while wait_for_operation and not operation.get('done'):
# We just got this operation. Wait before getting another
# so we don't exceed the GetOperation maximum request rate.
self._exponential_backoff(current_attempt, stop_time)
operation = self.get_operation(op_name)
current_attempt += 1

if operation.get('done'):
# Operations which are immediately done don't have an operation name
if operation.get('response'):
return operation.get('response')
elif operation.get('error'):
raise _utils.handle_operation_error(operation.get('error'))

# If the operation is not complete or timed out, return a (locked) model instead
return get_model(model_id).as_dict()
raise exceptions.UnknownError(message='Internal Error: Malformed Operation.')
else:
op_name = operation.get('name')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Consider moving this else block logic into a helper function, e.g. _poll_until_complete(operation) or similar.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that might be misleading. We only do the polling if wait_for_operation is true. So we could either (a) call _poll_until_complete only when wait_for_operation is true, have it contain just the while loop, and then duplicate the done checking below the while loop in both handle_operation and _poll_until_complete, or (b) name it something else, like _poll_until_complete_if_waiting. I actually prefer the way it is now, but I'm open to other suggestions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe _poll_until_complete(operation, wait_for_operation)?

I'm fine with how it is implemented now. Just thought the else block could be a bit smaller for clarity. Your call.

_, model_id = _validate_and_parse_operation_name(op_name)
current_attempt = 0
start_time = datetime.datetime.now()
stop_time = (None if max_time_seconds is None else
start_time + datetime.timedelta(seconds=max_time_seconds))
while wait_for_operation and not operation.get('done'):
# We just got this operation. Wait before getting another
# so we don't exceed the GetOperation maximum request rate.
self._exponential_backoff(current_attempt, stop_time)
operation = self.get_operation(op_name)
current_attempt += 1

if operation.get('done'):
if operation.get('response'):
return operation.get('response')
elif operation.get('error'):
raise _utils.handle_operation_error(operation.get('error'))

# If the operation is not complete or timed out, return a (locked) model instead
return get_model(model_id).as_dict()


def create_model(self, model):
Expand All @@ -844,12 +869,12 @@ def create_model(self, model):

def update_model(self, model, update_mask=None):
_validate_model(model, update_mask)
data = {'model': model.as_dict(for_upload=True)}
path = 'models/{0}'.format(model.model_id)
if update_mask is not None:
data['updateMask'] = update_mask
path = path + '?updateMask={0}'.format(update_mask)
try:
return self.handle_operation(
self._client.body('patch', url='models/{0}'.format(model.model_id), json=data))
self._client.body('patch', url=path, json=model.as_dict(for_upload=True)))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

Expand All @@ -876,15 +901,20 @@ def list_models(self, list_filter, page_size, page_token):
_validate_list_filter(list_filter)
_validate_page_size(page_size)
_validate_page_token(page_token)
payload = {}
params = {}
if list_filter:
payload['list_filter'] = list_filter
params['filter'] = list_filter
if page_size:
payload['page_size'] = page_size
params['page_size'] = page_size
if page_token:
payload['page_token'] = page_token
params['page_token'] = page_token
path = 'models'
if params:
# pylint: disable=too-many-function-args
param_str = urllib.parse.urlencode(sorted(params.items()), True)
path = path + '?' + param_str
try:
return self._client.body('get', url='models', json=payload)
return self._client.body('get', url=path)
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

Expand Down
76 changes: 28 additions & 48 deletions tests/test_mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,23 +158,21 @@
}

OPERATION_DONE_MODEL_JSON_1 = {
'name': OPERATION_NAME_1,
'done': True,
'response': CREATED_UPDATED_MODEL_JSON_1
}
OPERATION_MALFORMED_JSON_1 = {
'name': OPERATION_NAME_1,
'done': True,
# if done is true then either response or error should be populated
}
OPERATION_MISSING_NAME = {
# Name is required if the operation is not done.
'done': False
}
OPERATION_ERROR_CODE = 400
OPERATION_ERROR_MSG = "Invalid argument"
OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT'
OPERATION_ERROR_JSON_1 = {
'name': OPERATION_NAME_1,
'done': True,
'error': {
'code': OPERATION_ERROR_CODE,
Expand Down Expand Up @@ -609,17 +607,10 @@ def test_operation_error(self):
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)

def test_malformed_operation(self):
recorder = instrument_mlkit_service(
status=[200, 200],
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
model = mlkit.create_model(MODEL_1)
assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'POST'
assert recorder[0].url == TestCreateModel._url(PROJECT_ID)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1)
instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE)
with pytest.raises(Exception) as excinfo:
mlkit.create_model(MODEL_1)
check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.')

def test_rpc_error_create(self):
create_recorder = instrument_mlkit_service(
Expand Down Expand Up @@ -708,17 +699,10 @@ def test_operation_error(self):
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)

def test_malformed_operation(self):
recorder = instrument_mlkit_service(
status=[200, 200],
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
model = mlkit.update_model(MODEL_1)
assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE)
with pytest.raises(Exception) as excinfo:
mlkit.update_model(MODEL_1)
check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.')

def test_rpc_error(self):
create_recorder = instrument_mlkit_service(
Expand Down Expand Up @@ -779,7 +763,13 @@ def teardown_class(cls):
testutils.cleanup_apps()

@staticmethod
def _url(project_id, model_id):
def _update_url(project_id, model_id):
update_url = 'projects/{0}/models/{1}?updateMask=state.published'.format(
project_id, model_id)
return BASE_URL + update_url

@staticmethod
def _get_url(project_id, model_id):
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)

@staticmethod
Expand All @@ -794,10 +784,9 @@ def test_immediate_done(self, publish_function, published):
assert model == CREATED_UPDATED_MODEL_1
assert len(recorder) == 1
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)
body = json.loads(recorder[0].body.decode())
assert body.get('model', {}).get('state', {}).get('published', None) is published
assert body.get('updateMask', {}) == 'state.published'
assert body.get('state', {}).get('published', None) is published

@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
def test_returns_locked(self, publish_function):
Expand All @@ -810,9 +799,9 @@ def test_returns_locked(self, publish_function):
assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1)

@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
def test_operation_error(self, publish_function):
Expand All @@ -824,17 +813,10 @@ def test_operation_error(self, publish_function):

@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
def test_malformed_operation(self, publish_function):
recorder = instrument_mlkit_service(
status=[200, 200],
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
model = publish_function(MODEL_ID_1)
assert model == expected_model
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].method == 'GET'
assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE)
with pytest.raises(Exception) as excinfo:
publish_function(MODEL_ID_1)
check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.')

@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
def test_rpc_error(self, publish_function):
Expand Down Expand Up @@ -996,12 +978,10 @@ def test_list_models_with_all_args(self):
page_token=PAGE_TOKEN)
assert len(recorder) == 1
assert recorder[0].method == 'GET'
assert recorder[0].url == TestListModels._url(PROJECT_ID)
assert json.loads(recorder[0].body.decode()) == {
'list_filter': 'display_name=displayName3',
'page_size': 10,
'page_token': PAGE_TOKEN
}
assert recorder[0].url == (
TestListModels._url(PROJECT_ID) +
'?filter=display_name%3DdisplayName3&page_size=10&page_token={0}'
.format(PAGE_TOKEN))
assert isinstance(models_page, mlkit.ListModelsPage)
assert len(models_page.models) == 1
assert models_page.models[0] == MODEL_3
Expand Down