Skip to content

Commit f0aedda

Browse files
committed
Firebase ML Kit Modify Operation Handling to not require a name for Done Operations
1 parent 7b4731f commit f0aedda

File tree

2 files changed

+37
-52
lines changed

2 files changed

+37
-52
lines changed

firebase_admin/mlkit.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -810,28 +810,36 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
810810
"""
811811
if not isinstance(operation, dict):
812812
raise TypeError('Operation must be a dictionary.')
813-
op_name = operation.get('name')
814-
_, model_id = _validate_and_parse_operation_name(op_name)
815-
816-
current_attempt = 0
817-
start_time = datetime.datetime.now()
818-
stop_time = (None if max_time_seconds is None else
819-
start_time + datetime.timedelta(seconds=max_time_seconds))
820-
while wait_for_operation and not operation.get('done'):
821-
# We just got this operation. Wait before getting another
822-
# so we don't exceed the GetOperation maximum request rate.
823-
self._exponential_backoff(current_attempt, stop_time)
824-
operation = self.get_operation(op_name)
825-
current_attempt += 1
826813

827814
if operation.get('done'):
815+
# Operations which are immediately done don't have an operation name
828816
if operation.get('response'):
829817
return operation.get('response')
830818
elif operation.get('error'):
831819
raise _utils.handle_operation_error(operation.get('error'))
832-
833-
# If the operation is not complete or timed out, return a (locked) model instead
834-
return get_model(model_id).as_dict()
820+
raise exceptions.UnknownError(message='Internal Error: Malformed Operation.')
821+
else:
822+
op_name = operation.get('name')
823+
_, model_id = _validate_and_parse_operation_name(op_name)
824+
current_attempt = 0
825+
start_time = datetime.datetime.now()
826+
stop_time = (None if max_time_seconds is None else
827+
start_time + datetime.timedelta(seconds=max_time_seconds))
828+
while wait_for_operation and not operation.get('done'):
829+
# We just got this operation. Wait before getting another
830+
# so we don't exceed the GetOperation maximum request rate.
831+
self._exponential_backoff(current_attempt, stop_time)
832+
operation = self.get_operation(op_name)
833+
current_attempt += 1
834+
835+
if operation.get('done'):
836+
if operation.get('response'):
837+
return operation.get('response')
838+
elif operation.get('error'):
839+
raise _utils.handle_operation_error(operation.get('error'))
840+
841+
# If the operation is not complete or timed out, return a (locked) model instead
842+
return get_model(model_id).as_dict()
835843

836844

837845
def create_model(self, model):

tests/test_mlkit.py

Lines changed: 13 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -158,23 +158,21 @@
158158
}
159159

160160
OPERATION_DONE_MODEL_JSON_1 = {
161-
'name': OPERATION_NAME_1,
162161
'done': True,
163162
'response': CREATED_UPDATED_MODEL_JSON_1
164163
}
165164
OPERATION_MALFORMED_JSON_1 = {
166-
'name': OPERATION_NAME_1,
167165
'done': True,
168166
# if done is true then either response or error should be populated
169167
}
170168
OPERATION_MISSING_NAME = {
169+
# Name is required if the operation is not done.
171170
'done': False
172171
}
173172
OPERATION_ERROR_CODE = 400
174173
OPERATION_ERROR_MSG = "Invalid argument"
175174
OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT'
176175
OPERATION_ERROR_JSON_1 = {
177-
'name': OPERATION_NAME_1,
178176
'done': True,
179177
'error': {
180178
'code': OPERATION_ERROR_CODE,
@@ -609,17 +607,10 @@ def test_operation_error(self):
609607
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)
610608

611609
def test_malformed_operation(self):
612-
recorder = instrument_mlkit_service(
613-
status=[200, 200],
614-
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
615-
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
616-
model = mlkit.create_model(MODEL_1)
617-
assert model == expected_model
618-
assert len(recorder) == 2
619-
assert recorder[0].method == 'POST'
620-
assert recorder[0].url == TestCreateModel._url(PROJECT_ID)
621-
assert recorder[1].method == 'GET'
622-
assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1)
610+
instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE)
611+
with pytest.raises(Exception) as excinfo:
612+
mlkit.create_model(MODEL_1)
613+
check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.')
623614

624615
def test_rpc_error_create(self):
625616
create_recorder = instrument_mlkit_service(
@@ -708,17 +699,10 @@ def test_operation_error(self):
708699
check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)
709700

710701
def test_malformed_operation(self):
711-
recorder = instrument_mlkit_service(
712-
status=[200, 200],
713-
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
714-
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
715-
model = mlkit.update_model(MODEL_1)
716-
assert model == expected_model
717-
assert len(recorder) == 2
718-
assert recorder[0].method == 'PATCH'
719-
assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
720-
assert recorder[1].method == 'GET'
721-
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
702+
instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE)
703+
with pytest.raises(Exception) as excinfo:
704+
mlkit.update_model(MODEL_1)
705+
check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.')
722706

723707
def test_rpc_error(self):
724708
create_recorder = instrument_mlkit_service(
@@ -824,17 +808,10 @@ def test_operation_error(self, publish_function):
824808

825809
@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
826810
def test_malformed_operation(self, publish_function):
827-
recorder = instrument_mlkit_service(
828-
status=[200, 200],
829-
payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
830-
expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
831-
model = publish_function(MODEL_ID_1)
832-
assert model == expected_model
833-
assert len(recorder) == 2
834-
assert recorder[0].method == 'PATCH'
835-
assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
836-
assert recorder[1].method == 'GET'
837-
assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1)
811+
instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE)
812+
with pytest.raises(Exception) as excinfo:
813+
publish_function(MODEL_ID_1)
814+
check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.')
838815

839816
@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
840817
def test_rpc_error(self, publish_function):

0 commit comments

Comments
 (0)