diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py
index 8cf8d1f7f..b9b56c8f4 100644
--- a/firebase_admin/mlkit.py
+++ b/firebase_admin/mlkit.py
@@ -73,6 +73,20 @@ def create_model(model, app=None):
     return Model.from_dict(mlkit_service.create_model(model), app=app)
 
 
+def update_model(model, app=None):
+    """Updates a model in Firebase ML Kit.
+
+    Args:
+        model: The mlkit.Model to update.
+        app: A Firebase app instance (or None to use the default app).
+
+    Returns:
+        Model: The updated model.
+    """
+    mlkit_service = _get_mlkit_service(app)
+    return Model.from_dict(mlkit_service.update_model(model), app=app)
+
+
 def get_model(model_id, app=None):
     """Gets a model from Firebase ML Kit.
 
@@ -469,10 +483,10 @@ def _validate_and_parse_name(name):
     return matcher.group('project_id'), matcher.group('model_id')
 
 
-def _validate_model(model):
+def _validate_model(model, update_mask=None):
     if not isinstance(model, Model):
         raise TypeError('Model must be an mlkit.Model.')
-    if not model.display_name:
+    if update_mask is None and not model.display_name:
         raise ValueError('Model must have a display name.')
 
 
@@ -634,6 +648,17 @@ def create_model(self, model):
         except requests.exceptions.RequestException as error:
             raise _utils.handle_platform_error_from_requests(error)
 
+    def update_model(self, model, update_mask=None):
+        _validate_model(model, update_mask)
+        data = {'model': model.as_dict()}
+        if update_mask is not None:
+            data['updateMask'] = update_mask
+        try:
+            return self.handle_operation(
+                self._client.body('patch', url='models/{0}'.format(model.model_id), json=data))
+        except requests.exceptions.RequestException as error:
+            raise _utils.handle_platform_error_from_requests(error)
+
     def get_model(self, model_id):
         _validate_model_id(model_id)
         try:
diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py
index 78afbdf49..e93bbd7e9 100644
--- a/tests/test_mlkit.py
+++ b/tests/test_mlkit.py
@@ -24,7 +24,6 @@
 from tests import testutils
 
 BASE_URL = 'https://mlkit.googleapis.com/v1beta1/'
-
 PROJECT_ID = 'myProject1'
 PAGE_TOKEN = 'pageToken'
 NEXT_PAGE_TOKEN = 'nextPageToken'
@@ -122,7 +121,7 @@
 }
 TFLITE_FORMAT_2 = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2)
 
-CREATED_MODEL_JSON_1 = {
+CREATED_UPDATED_MODEL_JSON_1 = {
     'name': MODEL_NAME_1,
     'displayName': DISPLAY_NAME_1,
     'createTime': CREATE_TIME_JSON,
@@ -132,7 +131,7 @@
     'modelHash': MODEL_HASH,
     'tags': TAGS,
 }
-CREATED_MODEL_1 = mlkit.Model.from_dict(CREATED_MODEL_JSON_1)
+CREATED_UPDATED_MODEL_1 = mlkit.Model.from_dict(CREATED_UPDATED_MODEL_JSON_1)
 
 LOCKED_MODEL_JSON_1 = {
     'name': MODEL_NAME_1,
@@ -155,19 +154,16 @@
 OPERATION_DONE_MODEL_JSON_1 = {
     'name': OPERATION_NAME_1,
     'done': True,
-    'response': CREATED_MODEL_JSON_1
+    'response': CREATED_UPDATED_MODEL_JSON_1
 }
-
 OPERATION_MALFORMED_JSON_1 = {
     'name': OPERATION_NAME_1,
     'done': True,
     # if done is true then either response or error should be populated
 }
-
 OPERATION_MISSING_NAME = {
     'done': False
 }
-
 OPERATION_ERROR_CODE = 400
 OPERATION_ERROR_MSG = "Invalid argument"
 OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT'
@@ -254,15 +250,33 @@
 }
 ERROR_RESPONSE_BAD_REQUEST = json.dumps(ERROR_JSON_BAD_REQUEST)
 
-invalid_model_id_args = [
+INVALID_MODEL_ID_ARGS = [
     ('', ValueError),
     ('&_*#@:/?', ValueError),
     (None, TypeError),
     (12345, TypeError),
 ]
+INVALID_MODEL_ARGS = [
+    'abc',
+    4.2,
+    list(),
+    dict(),
+    True,
+    -1,
+    0,
+    None
+]
+INVALID_OP_NAME_ARGS = [
+    'abc',
+    '123',
+    'projects/operations/project/1234/model/abc/operation/123',
+    'operations/project/model/abc/operation/123',
+    'operations/project/123/model/$#@/operation/123',
+    'operations/project/1234/model/abc/operation/123/extrathing',
+]
 PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \
                             '1 and {0}'.format(mlkit._MAX_PAGE_SIZE)
-invalid_string_or_none_args = [0, -1, 4.2, 0x10, False, list(), dict()]
+INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, list(), dict()]
 
 
 # For validation type errors
@@ -524,7 +538,7 @@ def _get_url(project_id, model_id):
     def test_immediate_done(self):
         instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE)
         model = mlkit.create_model(MODEL_1)
-        assert model == CREATED_MODEL_1
+        assert model == CREATED_UPDATED_MODEL_1
 
     def test_returns_locked(self):
         recorder = instrument_mlkit_service(
@@ -573,16 +587,7 @@ def test_rpc_error_create(self):
         )
         assert len(create_recorder) == 1
 
-    @pytest.mark.parametrize('model', [
-        'abc',
-        4.2,
-        list(),
-        dict(),
-        True,
-        -1,
-        0,
-        None
-    ])
+    @pytest.mark.parametrize('model', INVALID_MODEL_ARGS)
     def test_not_model(self, model):
         with pytest.raises(Exception) as excinfo:
             mlkit.create_model(model)
@@ -599,14 +604,7 @@ def test_missing_op_name(self):
             mlkit.create_model(MODEL_1)
         check_error(excinfo, TypeError)
 
-    @pytest.mark.parametrize('op_name', [
-        'abc',
-        '123',
-        'projects/operations/project/1234/model/abc/operation/123',
-        'operations/project/model/abc/operation/123',
-        'operations/project/123/model/$#@/operation/123',
-        'operations/project/1234/model/abc/operation/123/extrathing',
-    ])
+    @pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS)
     def test_invalid_op_name(self, op_name):
         payload = json.dumps({'name': op_name})
         instrument_mlkit_service(status=200, payload=payload)
@@ -615,6 +613,105 @@ def test_invalid_op_name(self, op_name):
         check_error(excinfo, ValueError, 'Operation name format is invalid.')
 
 
+class TestUpdateModel(object):
+    """Tests mlkit.update_model."""
+    @classmethod
+    def setup_class(cls):
+        cred = testutils.MockCredential()
+        firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
+        mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1  # shorter for test
+
+    @classmethod
+    def teardown_class(cls):
+        testutils.cleanup_apps()
+
+    @staticmethod
+    def _url(project_id, model_id):
+        return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)
+
+    @staticmethod
+    def _op_url(project_id, model_id):
+        return BASE_URL + \
+            'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id)
+
+    def test_immediate_done(self):
+        instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE)
+        model = mlkit.update_model(MODEL_1)
+        assert model == CREATED_UPDATED_MODEL_1
+
+    def test_returns_locked(self):
+        recorder = instrument_mlkit_service(
+            status=[200, 200],
+            payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE])
+        expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
+        model = mlkit.update_model(MODEL_1)
+
+        assert model == expected_model
+        assert len(recorder) == 2
+        assert recorder[0].method == 'PATCH'
+        assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
+        assert recorder[1].method == 'GET'
+        assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
+
+    def test_operation_error(self):
+        instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE)
+        with pytest.raises(Exception) as excinfo:
+            mlkit.update_model(MODEL_1)
+        # The http request succeeded; the returned operation contains a failure
+        check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG)
+
+    def test_malformed_operation(self):
+        recorder = instrument_mlkit_service(
+            status=[200, 200],
+            payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE])
+        expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2)
+        model = mlkit.update_model(MODEL_1)
+        assert model == expected_model
+        assert len(recorder) == 2
+        assert recorder[0].method == 'PATCH'
+        assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
+        assert recorder[1].method == 'GET'
+        assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
+
+    def test_rpc_error(self):
+        recorder = instrument_mlkit_service(
+            status=400, payload=ERROR_RESPONSE_BAD_REQUEST)
+        with pytest.raises(Exception) as excinfo:
+            mlkit.update_model(MODEL_1)
+        check_firebase_error(
+            excinfo,
+            ERROR_STATUS_BAD_REQUEST,
+            ERROR_CODE_BAD_REQUEST,
+            ERROR_MSG_BAD_REQUEST
+        )
+        assert len(recorder) == 1
+
+    @pytest.mark.parametrize('model', INVALID_MODEL_ARGS)
+    def test_not_model(self, model):
+        with pytest.raises(Exception) as excinfo:
+            mlkit.update_model(model)
+        check_error(excinfo, TypeError, 'Model must be an mlkit.Model.')
+
+    def test_missing_display_name(self):
+        with pytest.raises(Exception) as excinfo:
+            mlkit.update_model(mlkit.Model.from_dict({}))
+        check_error(excinfo, ValueError, 'Model must have a display name.')
+
+    def test_missing_op_name(self):
+        instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE)
+        with pytest.raises(Exception) as excinfo:
+            mlkit.update_model(MODEL_1)
+        check_error(excinfo, TypeError)
+
+    @pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS)
+    def test_invalid_op_name(self, op_name):
+        payload = json.dumps({'name': op_name})
+        instrument_mlkit_service(status=200, payload=payload)
+        with pytest.raises(Exception) as excinfo:
+            mlkit.update_model(MODEL_1)
+        check_error(excinfo, ValueError, 'Operation name format is invalid.')
+
+
 class TestGetModel(object):
     """Tests mlkit.get_model."""
     @classmethod
@@ -640,7 +737,7 @@ def test_get_model(self):
         assert model.model_id == MODEL_ID_1
         assert model.display_name == DISPLAY_NAME_1
 
-    @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args)
+    @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS)
     def test_get_model_validation_errors(self, model_id, exc_type):
         with pytest.raises(exc_type) as excinfo:
             mlkit.get_model(model_id)
@@ -690,7 +787,7 @@ def test_delete_model(self):
         assert recorder[0].method == 'DELETE'
         assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1)
 
-    @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args)
+    @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS)
     def test_delete_model_validation_errors(self, model_id, exc_type):
         with pytest.raises(exc_type) as excinfo:
             mlkit.delete_model(model_id)
@@ -771,7 +868,7 @@ def test_list_models_with_all_args(self):
         assert models_page.models[0] == MODEL_3
         assert not models_page.has_next_page
 
-    @pytest.mark.parametrize('list_filter', invalid_string_or_none_args)
+    @pytest.mark.parametrize('list_filter', INVALID_STRING_OR_NONE_ARGS)
     def test_list_models_list_filter_validation(self, list_filter):
         with pytest.raises(TypeError) as excinfo:
             mlkit.list_models(list_filter=list_filter)
@@ -792,7 +889,7 @@ def test_list_models_page_size_validation(self, page_size, exc_type, error_messa
             mlkit.list_models(page_size=page_size)
         check_error(excinfo, exc_type, error_message)
 
-    @pytest.mark.parametrize('page_token', invalid_string_or_none_args)
+    @pytest.mark.parametrize('page_token', INVALID_STRING_OR_NONE_ARGS)
     def test_list_models_page_token_validation(self, page_token):
         with pytest.raises(TypeError) as excinfo:
             mlkit.list_models(page_token=page_token)
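
A minimal usage sketch of the update_model API introduced by this change, for reviewer context only (not part of the patch). It assumes application-default credentials are configured, the project and model identifiers below are placeholders, and Model.display_name is settable as in the existing mlkit module.

# Hypothetical usage sketch; 'my-project-id' and 'my_model_id' are placeholders.
import firebase_admin
from firebase_admin import mlkit

firebase_admin.initialize_app(options={'projectId': 'my-project-id'})

model = mlkit.get_model('my_model_id')    # fetch an existing hosted model
model.display_name = 'new_display_name'   # change a mutable field (placeholder name)
updated = mlkit.update_model(model)       # issues PATCH models/{model_id} and polls the operation
print(updated.display_name)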