diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index e86a827e0..0e42419c6 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -49,6 +49,11 @@ def get_model(model_id, app=None): return Model(mlkit_service.get_model(model_id)) +def delete_model(model_id, app=None): + mlkit_service = _get_mlkit_service(app) + mlkit_service.delete_model(model_id) + + class Model(object): """A Firebase ML Kit Model object.""" def __init__(self, data): @@ -67,6 +72,13 @@ def __ne__(self, other): #TODO(ifielker): define the Model properties etc +def _validate_model_id(model_id): + if not isinstance(model_id, six.string_types): + raise TypeError('Model ID must be a string.') + if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id): + raise ValueError('Model ID format is invalid.') + + class _MLKitService(object): """Firebase MLKit service.""" @@ -84,11 +96,15 @@ def __init__(self, app): base_url=self._project_url) def get_model(self, model_id): - if not isinstance(model_id, six.string_types): - raise TypeError('Model ID must be a string.') - if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id): - raise ValueError('Model ID format is invalid.') + _validate_model_id(model_id) try: return self._client.body('get', url='models/{0}'.format(model_id)) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) + + def delete_model(self, model_id): + _validate_model_id(model_id) + try: + self._client.body('delete', url='models/{0}'.format(model_id)) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 85edaa4a1..b0aadffc6 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -33,7 +33,8 @@ 'displayName': DISPLAY_NAME_1 } MODEL_1 = mlkit.Model(MODEL_JSON_1) -_DEFAULT_RESPONSE = json.dumps(MODEL_JSON_1) +_DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) +_EMPTY_RESPONSE = json.dumps({}) ERROR_CODE = 404 ERROR_MSG = 'The resource was not found' @@ -47,6 +48,37 @@ } _ERROR_RESPONSE = json.dumps(ERROR_JSON) +invalid_model_id_args = [ + ('', ValueError, 'Model ID format is invalid.'), + ('&_*#@:/?', ValueError, 'Model ID format is invalid.'), + (None, TypeError, 'Model ID must be a string.'), + (12345, TypeError, 'Model ID must be a string.'), +] + +def check_error(err, err_type, msg): + assert isinstance(err, err_type) + assert str(err) == msg + + +def check_firebase_error(err, code, status, msg): + assert isinstance(err, exceptions.FirebaseError) + assert err.code == code + assert err.http_response is not None + assert err.http_response.status_code == status + assert str(err) == msg + + +def instrument_mlkit_service(app=None, status=200, payload=None): + if not app: + app = firebase_admin.get_app() + mlkit_service = mlkit._get_mlkit_service(app) + recorder = [] + mlkit_service._client.session.mount( + 'https://mlkit.googleapis.com', + testutils.MockAdapter(payload, status, recorder) + ) + return recorder + class TestGetModel(object): """Tests mlkit.get_model.""" @@ -60,71 +92,33 @@ def teardown_class(cls): testutils.cleanup_apps() @staticmethod - def check_error(err, err_type, msg): - assert isinstance(err, err_type) - assert str(err) == msg - - @staticmethod - def check_firebase_error(err, code, status, msg): - assert isinstance(err, exceptions.FirebaseError) - assert err.code == code - assert err.http_response is not None - assert err.http_response.status_code == status - assert str(err) == msg - - def _get_url(self, project_id, model_id): + def _url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) - def _instrument_mlkit_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE): - if not app: - app = firebase_admin.get_app() - mlkit_service = mlkit._get_mlkit_service(app) - recorder = [] - mlkit_service._client.session.mount( - 'https://mlkit.googleapis.com', - testutils.MockAdapter(payload, status, recorder) - ) - return mlkit_service, recorder - def test_get_model(self): - _, recorder = self._instrument_mlkit_service() + recorder = instrument_mlkit_service(status=200, payload=_DEFAULT_GET_RESPONSE) model = mlkit.get_model(MODEL_ID_1) assert len(recorder) == 1 assert recorder[0].method == 'GET' - assert recorder[0].url == self._get_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) assert model == MODEL_1 assert model._data['name'] == MODEL_NAME_1 assert model._data['displayName'] == DISPLAY_NAME_1 - def test_get_model_validation_errors(self): - #Empty model-id - with pytest.raises(ValueError) as err: - mlkit.get_model('') - self.check_error(err.value, ValueError, 'Model ID format is invalid.') - - #None model-id - with pytest.raises(TypeError) as err: - mlkit.get_model(None) - self.check_error(err.value, TypeError, 'Model ID must be a string.') - - #Wrong type - with pytest.raises(TypeError) as err: - mlkit.get_model(12345) - self.check_error(err.value, TypeError, 'Model ID must be a string.') - - #Invalid characters - with pytest.raises(ValueError) as err: - mlkit.get_model('&_*#@:/?') - self.check_error(err.value, ValueError, 'Model ID format is invalid.') + @pytest.mark.parametrize('model_id, exc_type, error_message', invalid_model_id_args) + def test_get_model_validation_errors(self, model_id, exc_type, error_message): + with pytest.raises(exc_type) as err: + mlkit.get_model(model_id) + check_error(err.value, exc_type, error_message) def test_get_model_error(self): - _, recorder = self._instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE) + recorder = instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE) with pytest.raises(exceptions.NotFoundError) as err: mlkit.get_model(MODEL_ID_1) - self.check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE, ERROR_MSG) + check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE, ERROR_MSG) assert len(recorder) == 1 assert recorder[0].method == 'GET' - assert recorder[0].url == self._get_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1) def test_no_project_id(self): def evaluate(): @@ -132,3 +126,47 @@ def evaluate(): with pytest.raises(ValueError): mlkit.get_model(MODEL_ID_1, app) testutils.run_without_project_id(evaluate) + +class TestDeleteModel(object): + """Tests mlkit.delete_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + def test_delete_model(self): + recorder = instrument_mlkit_service(status=200, payload=_EMPTY_RESPONSE) + mlkit.delete_model(MODEL_ID_1) # no response for delete + assert len(recorder) == 1 + assert recorder[0].method == 'DELETE' + assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) + + @pytest.mark.parametrize('model_id, exc_type, error_message', invalid_model_id_args) + def test_delete_model_validation_errors(self, model_id, exc_type, error_message): + with pytest.raises(exc_type) as err: + mlkit.delete_model(model_id) + check_error(err.value, exc_type, error_message) + + def test_delete_model_error(self): + recorder = instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE) + with pytest.raises(exceptions.NotFoundError) as err: + mlkit.delete_model(MODEL_ID_1) + check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE, ERROR_MSG) + assert len(recorder) == 1 + assert recorder[0].method == 'DELETE' + assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1) + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + mlkit.delete_model(MODEL_ID_1, app) + testutils.run_without_project_id(evaluate)