diff --git a/.travis.yml b/.travis.yml index 4db3c3708..c89a1db76 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,7 +16,6 @@ before_install: - nvm install 8 && npm install -g firebase-tools script: - pytest - - firebase emulators:exec --only database --project fake-project-id 'pytest integration/test_db.py' cache: pip: true npm: true diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 9271fd668..e86a827e0 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -18,8 +18,77 @@ deleting, publishing and unpublishing Firebase ML Kit models. """ +import re +import requests +import six + +from firebase_admin import _http_client +from firebase_admin import _utils + + +_MLKIT_ATTRIBUTE = '_mlkit' + + +def _get_mlkit_service(app): + """ Returns an _MLKitService instance for an App. + + Args: + app: A Firebase App instance (or None to use the default App). + + Returns: + _MLKitService: An _MLKitService for the specified App instance. + + Raises: + ValueError: If the app argument is invalid. + """ + return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService) + + +def get_model(model_id, app=None): + mlkit_service = _get_mlkit_service(app) + return Model(mlkit_service.get_model(model_id)) + + +class Model(object): + """A Firebase ML Kit Model object.""" + def __init__(self, data): + """Created from a data dictionary.""" + self._data = data + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self._data == other._data # pylint: disable=protected-access + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + #TODO(ifielker): define the Model properties etc + + class _MLKitService(object): """Firebase MLKit service.""" - BASE_URL = 'https://mlkit.googleapis.com' - PROJECT_URL = 'https://mlkit.googleapis.com/projects/{0}/' + PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/' + + def __init__(self, app): + project_id = app.project_id + if not project_id: + raise ValueError( + 'Project ID is required to access MLKit service. Either set the ' + 'projectId option, or use service account credentials.') + self._project_url = _MLKitService.PROJECT_URL.format(project_id) + self._client = _http_client.JsonHttpClient( + credential=app.credential.get_credential(), + base_url=self._project_url) + + def get_model(self, model_id): + if not isinstance(model_id, six.string_types): + raise TypeError('Model ID must be a string.') + if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id): + raise ValueError('Model ID format is invalid.') + try: + return self._client.body('get', url='models/{0}'.format(model_id)) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py new file mode 100644 index 000000000..85edaa4a1 --- /dev/null +++ b/tests/test_mlkit.py @@ -0,0 +1,134 @@ +# Copyright 2019 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin.mlkit module.""" + +import json +import pytest + +import firebase_admin +from firebase_admin import exceptions +from firebase_admin import mlkit +from tests import testutils + +BASE_URL = 'https://mlkit.googleapis.com/v1beta1/' + +PROJECT_ID = 'myProject1' +MODEL_ID_1 = 'modelId1' +MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1) +DISPLAY_NAME_1 = 'displayName1' +MODEL_JSON_1 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1 +} +MODEL_1 = mlkit.Model(MODEL_JSON_1) +_DEFAULT_RESPONSE = json.dumps(MODEL_JSON_1) + +ERROR_CODE = 404 +ERROR_MSG = 'The resource was not found' +ERROR_STATUS = 'NOT_FOUND' +ERROR_JSON = { + 'error': { + 'code': ERROR_CODE, + 'message': ERROR_MSG, + 'status': ERROR_STATUS + } +} +_ERROR_RESPONSE = json.dumps(ERROR_JSON) + + +class TestGetModel(object): + """Tests mlkit.get_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def check_error(err, err_type, msg): + assert isinstance(err, err_type) + assert str(err) == msg + + @staticmethod + def check_firebase_error(err, code, status, msg): + assert isinstance(err, exceptions.FirebaseError) + assert err.code == code + assert err.http_response is not None + assert err.http_response.status_code == status + assert str(err) == msg + + def _get_url(self, project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + def _instrument_mlkit_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE): + if not app: + app = firebase_admin.get_app() + mlkit_service = mlkit._get_mlkit_service(app) + recorder = [] + mlkit_service._client.session.mount( + 'https://mlkit.googleapis.com', + testutils.MockAdapter(payload, status, recorder) + ) + return mlkit_service, recorder + + def test_get_model(self): + _, recorder = self._instrument_mlkit_service() + model = mlkit.get_model(MODEL_ID_1) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == self._get_url(PROJECT_ID, MODEL_ID_1) + assert model == MODEL_1 + assert model._data['name'] == MODEL_NAME_1 + assert model._data['displayName'] == DISPLAY_NAME_1 + + def test_get_model_validation_errors(self): + #Empty model-id + with pytest.raises(ValueError) as err: + mlkit.get_model('') + self.check_error(err.value, ValueError, 'Model ID format is invalid.') + + #None model-id + with pytest.raises(TypeError) as err: + mlkit.get_model(None) + self.check_error(err.value, TypeError, 'Model ID must be a string.') + + #Wrong type + with pytest.raises(TypeError) as err: + mlkit.get_model(12345) + self.check_error(err.value, TypeError, 'Model ID must be a string.') + + #Invalid characters + with pytest.raises(ValueError) as err: + mlkit.get_model('&_*#@:/?') + self.check_error(err.value, ValueError, 'Model ID format is invalid.') + + def test_get_model_error(self): + _, recorder = self._instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE) + with pytest.raises(exceptions.NotFoundError) as err: + mlkit.get_model(MODEL_ID_1) + self.check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE, ERROR_MSG) + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == self._get_url(PROJECT_ID, MODEL_ID_1) + + def test_no_project_id(self): + def evaluate(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') + with pytest.raises(ValueError): + mlkit.get_model(MODEL_ID_1, app) + testutils.run_without_project_id(evaluate)