diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py
index 0e42419c6..dd02ea8db 100644
--- a/firebase_admin/mlkit.py
+++ b/firebase_admin/mlkit.py
@@ -27,6 +27,7 @@
 
 
 _MLKIT_ATTRIBUTE = '_mlkit'
+_MAX_PAGE_SIZE = 100
 
 
 def _get_mlkit_service(app):
@@ -49,6 +50,12 @@ def get_model(model_id, app=None):
     return Model(mlkit_service.get_model(model_id))
 
 
+def list_models(list_filter=None, page_size=None, page_token=None, app=None):
+    mlkit_service = _get_mlkit_service(app)
+    return ListModelsPage(
+        mlkit_service.list_models, list_filter, page_size, page_token)
+
+
 def delete_model(model_id, app=None):
     mlkit_service = _get_mlkit_service(app)
     mlkit_service.delete_model(model_id)
@@ -69,7 +76,107 @@ def __eq__(self, other):
     def __ne__(self, other):
         return not self.__eq__(other)
 
-    #TODO(ifielker): define the Model properties etc
+    @property
+    def name(self):
+        return self._data['name']
+
+    @property
+    def display_name(self):
+        return self._data['displayName']
+
+    #TODO(ifielker): define the rest of the Model properties etc
+
+
+class ListModelsPage(object):
+    """Represents a page of models in a firebase project.
+
+    Provides methods for traversing the models included in this page, as well as
+    retrieving subsequent pages of models. The iterator returned by
+    ``iterate_all()`` can be used to iterate through all the models in the
+    Firebase project starting from this page.
+    """
+    def __init__(self, list_models_func, list_filter, page_size, page_token):
+        self._list_models_func = list_models_func
+        self._list_filter = list_filter
+        self._page_size = page_size
+        self._page_token = page_token
+        self._list_response = list_models_func(list_filter, page_size, page_token)
+
+    @property
+    def models(self):
+        """A list of Models from this page."""
+        return [Model(model) for model in self._list_response.get('models', [])]
+
+    @property
+    def list_filter(self):
+        """The filter string used to filter the models."""
+        return self._list_filter
+
+    @property
+    def next_page_token(self):
+        return self._list_response.get('nextPageToken', '')
+
+    @property
+    def has_next_page(self):
+        """A boolean indicating whether more pages are available."""
+        return bool(self.next_page_token)
+
+    def get_next_page(self):
+        """Retrieves the next page of models if available.
+
+        Returns:
+            ListModelsPage: Next page of models, or None if this is the last page.
+        """
+        if self.has_next_page:
+            return ListModelsPage(
+                self._list_models_func,
+                self._list_filter,
+                self._page_size,
+                self.next_page_token)
+        return None
+
+    def iterate_all(self):
+        """Retrieves an iterator for Models.
+
+        Returned iterator will iterate through all the models in the Firebase
+        project starting from this page. The iterator will never buffer more than
+        one page of models in memory at a time.
+
+        Returns:
+            iterator: An iterator of Model instances.
+        """
+        return _ModelIterator(self)
+
+
+class _ModelIterator(object):
+    """An iterator that allows iterating over models, one at a time.
+
+    This implementation loads a page of models into memory, and iterates on them.
+    When the whole page has been traversed, it loads another page. This class
+    never keeps more than one page of entries in memory.
+    """
+    def __init__(self, current_page):
+        if not isinstance(current_page, ListModelsPage):
+            raise TypeError('Current page must be a ListModelsPage')
+        self._current_page = current_page
+        self._index = 0
+
+    def next(self):
+        if self._index == len(self._current_page.models):
+            if self._current_page.has_next_page:
+                self._current_page = self._current_page.get_next_page()
+                self._index = 0
+        if self._index < len(self._current_page.models):
+            result = self._current_page.models[self._index]
+            self._index += 1
+            return result
+        raise StopIteration
+
+    def __next__(self):
+        return self.next()
+
+    def __iter__(self):
+        return self
 
 
 def _validate_model_id(model_id):
@@ -79,6 +186,28 @@ def _validate_model_id(model_id):
         raise ValueError('Model ID format is invalid.')
 
 
+def _validate_list_filter(list_filter):
+    if list_filter is not None:
+        if not isinstance(list_filter, six.string_types):
+            raise TypeError('List filter must be a string or None.')
+
+
+def _validate_page_size(page_size):
+    if page_size is not None:
+        if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck
+            # Specifically type() to disallow boolean which is a subtype of int
+            raise TypeError('Page size must be a number or None.')
+        if page_size < 1 or page_size > _MAX_PAGE_SIZE:
+            raise ValueError('Page size must be a positive integer between '
+                             '1 and {0}'.format(_MAX_PAGE_SIZE))
+
+
+def _validate_page_token(page_token):
+    if page_token is not None:
+        if not isinstance(page_token, six.string_types):
+            raise TypeError('Page token must be a string or None.')
+
+
 class _MLKitService(object):
     """Firebase MLKit service."""
 
@@ -102,6 +231,23 @@ def get_model(self, model_id):
         except requests.exceptions.RequestException as error:
             raise _utils.handle_platform_error_from_requests(error)
 
+    def list_models(self, list_filter, page_size, page_token):
+        """ lists Firebase ML Kit models."""
+        _validate_list_filter(list_filter)
+        _validate_page_size(page_size)
+        _validate_page_token(page_token)
+        payload = {}
+        if list_filter:
+            payload['list_filter'] = list_filter
+        if page_size:
+            payload['page_size'] = page_size
+        if page_token:
+            payload['page_token'] = page_token
+        try:
+            return self._client.body('get', url='models', json=payload)
+        except requests.exceptions.RequestException as error:
+            raise _utils.handle_platform_error_from_requests(error)
+
     def delete_model(self, model_id):
         _validate_model_id(model_id)
         try:
diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py
index b0aadffc6..e14bd4371 100644
--- a/tests/test_mlkit.py
+++ b/tests/test_mlkit.py
@@ -25,6 +25,8 @@
 BASE_URL = 'https://mlkit.googleapis.com/v1beta1/'
 
 PROJECT_ID = 'myProject1'
+PAGE_TOKEN = 'pageToken'
+NEXT_PAGE_TOKEN = 'nextPageToken'
 MODEL_ID_1 = 'modelId1'
 MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1)
 DISPLAY_NAME_1 = 'displayName1'
@@ -33,20 +35,62 @@
     'displayName': DISPLAY_NAME_1
 }
 MODEL_1 = mlkit.Model(MODEL_JSON_1)
-_DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1)
-_EMPTY_RESPONSE = json.dumps({})
-ERROR_CODE = 404
-ERROR_MSG = 'The resource was not found'
-ERROR_STATUS = 'NOT_FOUND'
-ERROR_JSON = {
+MODEL_ID_2 = 'modelId2'
+MODEL_NAME_2 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_2)
+DISPLAY_NAME_2 = 'displayName2'
+MODEL_JSON_2 = {
+    'name': MODEL_NAME_2,
+    'displayName': DISPLAY_NAME_2
+}
+MODEL_2 = mlkit.Model(MODEL_JSON_2)
+
+MODEL_ID_3 = 'modelId3'
+MODEL_NAME_3 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_3)
+DISPLAY_NAME_3 = 'displayName3'
+MODEL_JSON_3 = {
+    'name': MODEL_NAME_3,
+    'displayName': DISPLAY_NAME_3
+}
+MODEL_3 = mlkit.Model(MODEL_JSON_3)
+
+EMPTY_RESPONSE = json.dumps({})
+DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1)
+NO_MODELS_LIST_RESPONSE = json.dumps({})
+DEFAULT_LIST_RESPONSE = json.dumps({
+    'models': [MODEL_JSON_1, MODEL_JSON_2],
+    'nextPageToken': NEXT_PAGE_TOKEN
+})
+LAST_PAGE_LIST_RESPONSE = json.dumps({
+    'models': [MODEL_JSON_3]
+})
+ONE_PAGE_LIST_RESPONSE = json.dumps({
+    'models': [MODEL_JSON_1, MODEL_JSON_2, MODEL_JSON_3],
+})
+
+ERROR_CODE_NOT_FOUND = 404
+ERROR_MSG_NOT_FOUND = 'The resource was not found'
+ERROR_STATUS_NOT_FOUND = 'NOT_FOUND'
+ERROR_JSON_NOT_FOUND = {
     'error': {
-        'code': ERROR_CODE,
-        'message': ERROR_MSG,
-        'status': ERROR_STATUS
+        'code': ERROR_CODE_NOT_FOUND,
+        'message': ERROR_MSG_NOT_FOUND,
+        'status': ERROR_STATUS_NOT_FOUND
     }
 }
-_ERROR_RESPONSE = json.dumps(ERROR_JSON)
+ERROR_RESPONSE_NOT_FOUND = json.dumps(ERROR_JSON_NOT_FOUND)
+
+ERROR_CODE_BAD_REQUEST = 400
+ERROR_MSG_BAD_REQUEST = 'Invalid Argument'
+ERROR_STATUS_BAD_REQUEST = 'INVALID_ARGUMENT'
+ERROR_JSON_BAD_REQUEST = {
+    'error': {
+        'code': ERROR_CODE_BAD_REQUEST,
+        'message': ERROR_MSG_BAD_REQUEST,
+        'status': ERROR_STATUS_BAD_REQUEST
+    }
+}
+ERROR_RESPONSE_BAD_REQUEST = json.dumps(ERROR_JSON_BAD_REQUEST)
 
 
 invalid_model_id_args = [
     ('', ValueError, 'Model ID format is invalid.'),
@@ -54,6 +98,20 @@
     (None, TypeError, 'Model ID must be a string.'),
     (12345, TypeError, 'Model ID must be a string.'),
 ]
+PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \
+                            '1 and {0}'.format(mlkit._MAX_PAGE_SIZE)
+invalid_page_size_args = [
+    ('abc', TypeError, 'Page size must be a number or None.'),
+    (4.2, TypeError, 'Page size must be a number or None.'),
+    (list(), TypeError, 'Page size must be a number or None.'),
+    (dict(), TypeError, 'Page size must be a number or None.'),
+    (True, TypeError, 'Page size must be a number or None.'),
+    (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG),
+    (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG),
+    (mlkit._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG)
+]
+invalid_string_or_none_args = [0, -1, 4.2, 0x10, False, list(), dict()]
+
 
 def check_error(err, err_type, msg):
     assert isinstance(err, err_type)
@@ -96,14 +154,14 @@ def _url(project_id, model_id):
         return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)
 
     def test_get_model(self):
-        recorder = instrument_mlkit_service(status=200, payload=_DEFAULT_GET_RESPONSE)
+        recorder = instrument_mlkit_service(status=200, payload=DEFAULT_GET_RESPONSE)
         model = mlkit.get_model(MODEL_ID_1)
         assert len(recorder) == 1
         assert recorder[0].method == 'GET'
         assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1)
         assert model == MODEL_1
-        assert model._data['name'] == MODEL_NAME_1
-        assert model._data['displayName'] == DISPLAY_NAME_1
+        assert model.name == MODEL_NAME_1
+        assert model.display_name == DISPLAY_NAME_1
 
     @pytest.mark.parametrize('model_id, exc_type, error_message', invalid_model_id_args)
     def test_get_model_validation_errors(self, model_id, exc_type, error_message):
@@ -112,13 +170,18 @@ def test_get_model_validation_errors(self, model_id, exc_type, error_message):
         check_error(err.value, exc_type, error_message)
 
     def test_get_model_error(self):
-        recorder = instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE)
+        recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND)
         with pytest.raises(exceptions.NotFoundError) as err:
             mlkit.get_model(MODEL_ID_1)
-        check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE, ERROR_MSG)
+        check_firebase_error(
+            err.value,
+            ERROR_STATUS_NOT_FOUND,
+            ERROR_CODE_NOT_FOUND,
+            ERROR_MSG_NOT_FOUND
+        )
         assert len(recorder) == 1
         assert recorder[0].method == 'GET'
-        assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1)
+        assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1)
 
     def test_no_project_id(self):
         def evaluate():
@@ -143,7 +206,7 @@ def _url(project_id, model_id):
         return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)
 
     def test_delete_model(self):
-        recorder = instrument_mlkit_service(status=200, payload=_EMPTY_RESPONSE)
+        recorder = instrument_mlkit_service(status=200, payload=EMPTY_RESPONSE)
         mlkit.delete_model(MODEL_ID_1)  # no response for delete
         assert len(recorder) == 1
         assert recorder[0].method == 'DELETE'
@@ -156,10 +219,15 @@ def test_delete_model_validation_errors(self, model_id, exc_type, error_message)
         check_error(err.value, exc_type, error_message)
 
     def test_delete_model_error(self):
-        recorder = instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE)
+        recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND)
         with pytest.raises(exceptions.NotFoundError) as err:
             mlkit.delete_model(MODEL_ID_1)
-        check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE, ERROR_MSG)
+        check_firebase_error(
+            err.value,
+            ERROR_STATUS_NOT_FOUND,
+            ERROR_CODE_NOT_FOUND,
+            ERROR_MSG_NOT_FOUND
+        )
         assert len(recorder) == 1
         assert recorder[0].method == 'DELETE'
         assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1)
@@ -170,3 +238,163 @@ def evaluate():
             with pytest.raises(ValueError):
                 mlkit.delete_model(MODEL_ID_1, app)
         testutils.run_without_project_id(evaluate)
+
+
+class TestListModels(object):
+    """Tests mlkit.list_models."""
+    @classmethod
+    def setup_class(cls):
+        cred = testutils.MockCredential()
+        firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
+
+    @classmethod
+    def teardown_class(cls):
+        testutils.cleanup_apps()
+
+    @staticmethod
+    def _url(project_id):
+        return BASE_URL + 'projects/{0}/models'.format(project_id)
+
+    @staticmethod
+    def _check_page(page, model_count):
+        assert isinstance(page, mlkit.ListModelsPage)
+        assert len(page.models) == model_count
+        for model in page.models:
+            assert isinstance(model, mlkit.Model)
+
+    def test_list_models_no_args(self):
+        recorder = instrument_mlkit_service(status=200, payload=DEFAULT_LIST_RESPONSE)
+        models_page = mlkit.list_models()
+        assert len(recorder) == 1
+        assert recorder[0].method == 'GET'
+        assert recorder[0].url == TestListModels._url(PROJECT_ID)
+        TestListModels._check_page(models_page, 2)
+        assert models_page.has_next_page
+        assert models_page.next_page_token == NEXT_PAGE_TOKEN
+        assert models_page.models[0] == MODEL_1
+        assert models_page.models[1] == MODEL_2
+
+    def test_list_models_with_all_args(self):
+        recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE)
+        models_page = mlkit.list_models(
+            'display_name=displayName3',
+            page_size=10,
+            page_token=PAGE_TOKEN)
+        assert len(recorder) == 1
+        assert recorder[0].method == 'GET'
+        assert recorder[0].url == TestListModels._url(PROJECT_ID)
+        assert json.loads(recorder[0].body.decode()) == {
+            'list_filter': 'display_name=displayName3',
+            'page_size': 10,
+            'page_token': PAGE_TOKEN
+        }
+        assert isinstance(models_page, mlkit.ListModelsPage)
+        assert len(models_page.models) == 1
+        assert models_page.models[0] == MODEL_3
+        assert not models_page.has_next_page
+
+    @pytest.mark.parametrize('list_filter', invalid_string_or_none_args)
+    def test_list_models_list_filter_validation(self, list_filter):
+        with pytest.raises(TypeError) as err:
+            mlkit.list_models(list_filter=list_filter)
+        check_error(err.value, TypeError, 'List filter must be a string or None.')
+
+    @pytest.mark.parametrize('page_size, exc_type, error_message', invalid_page_size_args)
+    def test_list_models_page_size_validation(self, page_size, exc_type, error_message):
+        with pytest.raises(exc_type) as err:
+            mlkit.list_models(page_size=page_size)
+        check_error(err.value, exc_type, error_message)
+
+    @pytest.mark.parametrize('page_token', invalid_string_or_none_args)
+    def test_list_models_page_token_validation(self, page_token):
+        with pytest.raises(TypeError) as err:
+            mlkit.list_models(page_token=page_token)
+        check_error(err.value, TypeError, 'Page token must be a string or None.')
+
+    def test_list_models_error(self):
+        recorder = instrument_mlkit_service(status=400, payload=ERROR_RESPONSE_BAD_REQUEST)
+        with pytest.raises(exceptions.InvalidArgumentError) as err:
+            mlkit.list_models()
+        check_firebase_error(
+            err.value,
+            ERROR_STATUS_BAD_REQUEST,
+            ERROR_CODE_BAD_REQUEST,
+            ERROR_MSG_BAD_REQUEST
+        )
+        assert len(recorder) == 1
+        assert recorder[0].method == 'GET'
+        assert recorder[0].url == TestListModels._url(PROJECT_ID)
+
+    def test_no_project_id(self):
+        def evaluate():
+            app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id')
+            with pytest.raises(ValueError):
+                mlkit.list_models(app=app)
+        testutils.run_without_project_id(evaluate)
+
+    def test_list_single_page(self):
+        recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE)
+        models_page = mlkit.list_models()
+        assert len(recorder) == 1
+        assert models_page.next_page_token == ''
+        assert models_page.has_next_page is False
+        assert models_page.get_next_page() is None
+        models = [model for model in models_page.iterate_all()]
+        assert len(models) == 1
+
+    def test_list_multiple_pages(self):
+        # Page 1
+        recorder = instrument_mlkit_service(status=200, payload=DEFAULT_LIST_RESPONSE)
+        page = mlkit.list_models()
+        assert len(recorder) == 1
+        assert len(page.models) == 2
+        assert page.next_page_token == NEXT_PAGE_TOKEN
+        assert page.has_next_page is True
+
+        # Page 2
+        recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE)
+        page_2 = page.get_next_page()
+        assert len(recorder) == 1
+        assert len(page_2.models) == 1
+        assert page_2.next_page_token == ''
+        assert page_2.has_next_page is False
+        assert page_2.get_next_page() is None
+
+    def test_list_models_paged_iteration(self):
+        # Page 1
+        recorder = instrument_mlkit_service(status=200, payload=DEFAULT_LIST_RESPONSE)
+        page = mlkit.list_models()
+        assert page.next_page_token == NEXT_PAGE_TOKEN
+        assert page.has_next_page is True
+        iterator = page.iterate_all()
+        for index in range(2):
+            model = next(iterator)
+            assert model.display_name == 'displayName{0}'.format(index+1)
+        assert len(recorder) == 1
+
+        # Page 2
+        recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE)
+        model = next(iterator)
+        assert model.display_name == DISPLAY_NAME_3
+        with pytest.raises(StopIteration):
+            next(iterator)
+
+    def test_list_models_stop_iteration(self):
+        recorder = instrument_mlkit_service(status=200, payload=ONE_PAGE_LIST_RESPONSE)
+        page = mlkit.list_models()
+        assert len(recorder) == 1
+        assert len(page.models) == 3
+        iterator = page.iterate_all()
+        models = [model for model in iterator]
+        assert len(page.models) == 3
+        with pytest.raises(StopIteration):
+            next(iterator)
+        assert len(models) == 3
+
+    def test_list_models_no_models(self):
+        recorder = instrument_mlkit_service(status=200, payload=NO_MODELS_LIST_RESPONSE)
+        page = mlkit.list_models()
+        assert len(recorder) == 1
+        assert len(page.models) == 0
+        models = [model for model in page.iterate_all()]
+        assert len(models) == 0
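
Reviewer note: a minimal usage sketch of the list_models() API added above, as exercised by the tests in this diff. The credential setup, project ID, and filter value below are illustrative placeholders and not part of this change.

    import firebase_admin
    from firebase_admin import mlkit

    # Placeholder setup: assumes application-default credentials and an
    # illustrative project ID.
    firebase_admin.initialize_app(options={'projectId': 'my-project'})

    # Fetch the first page, optionally filtered and capped at <= 100 models
    # per page (the filter string here mirrors the test fixtures).
    page = mlkit.list_models(list_filter='display_name=displayName3', page_size=50)

    # Iterate every model across all pages; the iterator buffers only one
    # page of models in memory at a time.
    for model in page.iterate_all():
        print(model.name, model.display_name)

    # Or walk the pages manually; get_next_page() returns None after the
    # last page, which ends the loop.
    while page:
        for model in page.models:
            print(model.display_name)
        page = page.get_next_page()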