diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 95ed2c414..fb6e32932 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -106,6 +106,27 @@ def handle_platform_error_from_requests(error, handle_func=None): return exc if exc else _handle_func_requests(error, message, error_dict) +def handle_operation_error(error): + """Constructs a ``FirebaseError`` from the given operation error. + + Args: + error: An error returned by a long running operation. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if not isinstance(error, dict): + return exceptions.UnknownError( + message='Unknown error while making a remote service call: {0}'.format(error), + cause=error) + + status_code = error.get('code') + message = error.get('message') + error_code = _http_status_to_error_code(status_code) + err_type = _error_code_to_exception_type(error_code) + return err_type(message=message) + + def _handle_func_requests(error, message, error_dict): """Constructs a ``FirebaseError`` from the given GCP error. diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 3f1a825f6..8cf8d1f7f 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -21,11 +21,14 @@ import datetime import numbers import re +import time import requests import six + from firebase_admin import _http_client from firebase_admin import _utils +from firebase_admin import exceptions _MLKIT_ATTRIBUTE = '_mlkit' @@ -36,6 +39,9 @@ _GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+') _RESOURCE_NAME_PATTERN = re.compile( r'^projects/(?P[^/]+)/models/(?P[A-Za-z0-9_-]{1,60})$') +_OPERATION_NAME_PATTERN = re.compile( + r'^operations/project/(?P[^/]+)/model/(?P[A-Za-z0-9_-]{1,60})' + + r'/operation/[^/]+$') def _get_mlkit_service(app): @@ -53,18 +59,60 @@ def _get_mlkit_service(app): return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService) +def create_model(model, app=None): + """Creates a model in Firebase ML Kit. + + Args: + model: An mlkit.Model to create. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The model that was created in Firebase ML Kit. + """ + mlkit_service = _get_mlkit_service(app) + return Model.from_dict(mlkit_service.create_model(model), app=app) + + def get_model(model_id, app=None): + """Gets a model from Firebase ML Kit. + + Args: + model_id: The id of the model to get. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The requested model. + """ mlkit_service = _get_mlkit_service(app) - return Model.from_dict(mlkit_service.get_model(model_id)) + return Model.from_dict(mlkit_service.get_model(model_id), app=app) def list_models(list_filter=None, page_size=None, page_token=None, app=None): + """Lists models from Firebase ML Kit. + + Args: + list_filter: a list filter string such as "tags:'tag_1'". None will return all models. + page_size: A number between 1 and 100 inclusive that specifies the maximum + number of models to return per page. None for default. + page_token: A next page token returned from a previous page of results. None + for first page of results. + app: A Firebase app instance (or None to use the default app). + + Returns: + ListModelsPage: A (filtered) list of models. + """ mlkit_service = _get_mlkit_service(app) return ListModelsPage( - mlkit_service.list_models, list_filter, page_size, page_token) + mlkit_service.list_models, list_filter, page_size, page_token, app=app) def delete_model(model_id, app=None): + """Deletes a model from Firebase ML Kit. + + Args: + model_id: The id of the model you wish to delete. + app: A Firebase app instance (or None to use the default app). + """ mlkit_service = _get_mlkit_service(app) mlkit_service.delete_model(model_id) @@ -78,6 +126,7 @@ class Model(object): model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details. """ def __init__(self, display_name=None, tags=None, model_format=None): + self._app = None # Only needed for wait_for_unlo self._data = {} self._model_format = None @@ -89,7 +138,7 @@ def __init__(self, display_name=None, tags=None, model_format=None): self.model_format = model_format @classmethod - def from_dict(cls, data): + def from_dict(cls, data, app=None): data_copy = dict(data) tflite_format = None tflite_format_data = data_copy.pop('tfliteModel', None) @@ -97,8 +146,14 @@ def from_dict(cls, data): tflite_format = TFLiteFormat.from_dict(tflite_format_data) model = Model(model_format=tflite_format) model._data = data_copy # pylint: disable=protected-access + model._app = app # pylint: disable=protected-access return model + def _update_from_dict(self, data): + copy = Model.from_dict(data) + self.model_format = copy.model_format + self._data = copy._data # pylint: disable=protected-access + def __eq__(self, other): if isinstance(other, self.__class__): # pylint: disable=protected-access @@ -173,6 +228,26 @@ def locked(self): return bool(self._data.get('activeOperations') and len(self._data.get('activeOperations')) > 0) + def wait_for_unlocked(self, max_time_seconds=None): + """Waits for the model to be unlocked. (All active operations complete) + + Args: + max_time_seconds: The maximum number of seconds to wait for the model to unlock. + (None for no limit) + + Raises: + exceptions.DeadlineExceeded: If max_time_seconds passed and the model is still locked. + """ + if not self.locked: + return + mlkit_service = _get_mlkit_service(self._app) + op_name = self._data.get('activeOperations')[0].get('name') + model_dict = mlkit_service.handle_operation( + mlkit_service.get_operation(op_name), + wait_for_operation=True, + max_time_seconds=max_time_seconds) + self._update_from_dict(model_dict) + @property def model_format(self): return self._model_format @@ -296,17 +371,20 @@ class ListModelsPage(object): ``iterate_all()`` can be used to iterate through all the models in the Firebase project starting from this page. """ - def __init__(self, list_models_func, list_filter, page_size, page_token): + def __init__(self, list_models_func, list_filter, page_size, page_token, app): self._list_models_func = list_models_func self._list_filter = list_filter self._page_size = page_size self._page_token = page_token + self._app = app self._list_response = list_models_func(list_filter, page_size, page_token) @property def models(self): """A list of Models from this page.""" - return [Model.from_dict(model) for model in self._list_response.get('models', [])] + return [ + Model.from_dict(model, app=self._app) for model in self._list_response.get('models', []) + ] @property def list_filter(self): @@ -333,7 +411,8 @@ def get_next_page(self): self._list_models_func, self._list_filter, self._page_size, - self.next_page_token) + self.next_page_token, + self._app) return None def iterate_all(self): @@ -390,11 +469,25 @@ def _validate_and_parse_name(name): return matcher.group('project_id'), matcher.group('model_id') +def _validate_model(model): + if not isinstance(model, Model): + raise TypeError('Model must be an mlkit.Model.') + if not model.display_name: + raise ValueError('Model must have a display name.') + + def _validate_model_id(model_id): if not _MODEL_ID_PATTERN.match(model_id): raise ValueError('Model ID format is invalid.') +def _validate_and_parse_operation_name(op_name): + matcher = _OPERATION_NAME_PATTERN.match(op_name) + if not matcher: + raise ValueError('Operation name format is invalid.') + return matcher.group('project_id'), matcher.group('model_id') + + def _validate_display_name(display_name): if not _DISPLAY_NAME_PATTERN.match(display_name): raise ValueError('Display name format is invalid.') @@ -417,11 +510,13 @@ def _validate_gcs_tflite_uri(uri): raise ValueError('GCS TFLite URI format is invalid.') return uri + def _validate_model_format(model_format): if not isinstance(model_format, ModelFormat): raise TypeError('Model format must be a ModelFormat object.') return model_format + def _validate_list_filter(list_filter): if list_filter is not None: if not isinstance(list_filter, six.string_types): @@ -448,6 +543,9 @@ class _MLKitService(object): """Firebase MLKit service.""" PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/' + OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/' + POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5 + POLL_BASE_WAIT_TIME_SECONDS = 3 def __init__(self, app): project_id = app.project_id @@ -459,6 +557,82 @@ def __init__(self, app): self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), base_url=self._project_url) + self._operation_client = _http_client.JsonHttpClient( + credential=app.credential.get_credential(), + base_url=_MLKitService.OPERATION_URL) + + def get_operation(self, op_name): + _validate_and_parse_operation_name(op_name) + try: + return self._operation_client.body('get', url=op_name) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + + def _exponential_backoff(self, current_attempt, stop_time): + """Sleeps for the appropriate amount of time. Or throws deadline exceeded.""" + delay_factor = pow(_MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) + wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS + + if stop_time is not None: + max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds() + if max_seconds_left < 1: # allow a bit of time for rpc + raise exceptions.DeadlineExceededError('Polling max time exceeded.') + else: + wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) + time.sleep(wait_time_seconds) + + + def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None): + """Handles long running operations. + + Args: + operation: The operation to handle. + wait_for_operation: Should we allow polling for the operation to complete. + If no polling is requested, a locked model will be returned instead. + max_time_seconds: The maximum seconds to try polling for operation complete. + (None for no limit) + + Returns: + dict: A dictionary of the returned model properties. + + Raises: + TypeError: if the operation is not a dictionary. + ValueError: If the operation is malformed. + err: If the operation exceeds polling attempts or stop_time + """ + if not isinstance(operation, dict): + raise TypeError('Operation must be a dictionary.') + op_name = operation.get('name') + _, model_id = _validate_and_parse_operation_name(op_name) + + current_attempt = 0 + start_time = datetime.datetime.now() + stop_time = (None if max_time_seconds is None else + start_time + datetime.timedelta(seconds=max_time_seconds)) + while wait_for_operation and not operation.get('done'): + # We just got this operation. Wait before getting another + # so we don't exceed the GetOperation maximum request rate. + self._exponential_backoff(current_attempt, stop_time) + operation = self.get_operation(op_name) + current_attempt += 1 + + if operation.get('done'): + if operation.get('response'): + return operation.get('response') + elif operation.get('error'): + raise _utils.handle_operation_error(operation.get('error')) + + # If the operation is not complete or timed out, return a (locked) model instead + return get_model(model_id).as_dict() + + + def create_model(self, model): + _validate_model(model) + try: + return self.handle_operation( + self._client.body('post', url='models', json=model.as_dict())) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) def get_model(self, model_id): _validate_model_id(model_id) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index c20982a2b..78afbdf49 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -29,16 +29,24 @@ PAGE_TOKEN = 'pageToken' NEXT_PAGE_TOKEN = 'nextPageToken' CREATE_TIME_SECONDS = 1566426374 +CREATE_TIME_SECONDS_2 = 1566426385 CREATE_TIME_JSON = { 'seconds': CREATE_TIME_SECONDS } CREATE_TIME_DATETIME = datetime.datetime.fromtimestamp(CREATE_TIME_SECONDS) +CREATE_TIME_JSON_2 = { + 'seconds': CREATE_TIME_SECONDS_2 +} UPDATE_TIME_SECONDS = 1566426678 +UPDATE_TIME_SECONDS_2 = 1566426691 UPDATE_TIME_JSON = { 'seconds': UPDATE_TIME_SECONDS } UPDATE_TIME_DATETIME = datetime.datetime.fromtimestamp(UPDATE_TIME_SECONDS) +UPDATE_TIME_JSON_2 = { + 'seconds': UPDATE_TIME_SECONDS_2 +} ETAG = '33a64df551425fcc55e4d42a148795d9f25f89d4' MODEL_HASH = '987987a98b98798d098098e09809fc0893897' TAG_1 = 'Tag1' @@ -86,8 +94,9 @@ } } +OPERATION_NAME_1 = 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1) OPERATION_NOT_DONE_JSON_1 = { - 'name': 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1), + 'name': OPERATION_NAME_1, 'metadata': { '@type': 'type.googleapis.com/google.firebase.ml.v1beta1.ModelOperationMetadata', 'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1), @@ -113,6 +122,64 @@ } TFLITE_FORMAT_2 = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) +CREATED_MODEL_JSON_1 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME_JSON, + 'updateTime': UPDATE_TIME_JSON, + 'state': MODEL_STATE_ERROR_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, +} +CREATED_MODEL_1 = mlkit.Model.from_dict(CREATED_MODEL_JSON_1) + +LOCKED_MODEL_JSON_1 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME_JSON, + 'updateTime': UPDATE_TIME_JSON, + 'tags': TAGS, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1] +} + +LOCKED_MODEL_JSON_2 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_2, + 'createTime': CREATE_TIME_JSON_2, + 'updateTime': UPDATE_TIME_JSON_2, + 'tags': TAGS_2, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1] +} + +OPERATION_DONE_MODEL_JSON_1 = { + 'name': OPERATION_NAME_1, + 'done': True, + 'response': CREATED_MODEL_JSON_1 +} + +OPERATION_MALFORMED_JSON_1 = { + 'name': OPERATION_NAME_1, + 'done': True, + # if done is true then either response or error should be populated +} + +OPERATION_MISSING_NAME = { + 'done': False +} + +OPERATION_ERROR_CODE = 400 +OPERATION_ERROR_MSG = "Invalid argument" +OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT' +OPERATION_ERROR_JSON_1 = { + 'name': OPERATION_NAME_1, + 'done': True, + 'error': { + 'code': OPERATION_ERROR_CODE, + 'message': OPERATION_ERROR_MSG, + } +} + FULL_MODEL_ERR_STATE_LRO_JSON = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, @@ -135,9 +202,22 @@ 'tags': TAGS, 'tfliteModel': TFLITE_FORMAT_JSON } +FULL_MODEL_PUBLISHED = mlkit.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) +OPERATION_DONE_FULL_MODEL_PUBLISHED_JSON = { + 'name': OPERATION_NAME_1, + 'done': True, + 'response': FULL_MODEL_PUBLISHED_JSON +} EMPTY_RESPONSE = json.dumps({}) +OPERATION_NOT_DONE_RESPONSE = json.dumps(OPERATION_NOT_DONE_JSON_1) +OPERATION_DONE_RESPONSE = json.dumps(OPERATION_DONE_MODEL_JSON_1) +OPERATION_DONE_PUBLISHED_RESPONSE = json.dumps(OPERATION_DONE_FULL_MODEL_PUBLISHED_JSON) +OPERATION_ERROR_RESPONSE = json.dumps(OPERATION_ERROR_JSON_1) +OPERATION_MALFORMED_RESPONSE = json.dumps(OPERATION_MALFORMED_JSON_1) +OPERATION_MISSING_NAME_RESPONSE = json.dumps(OPERATION_MISSING_NAME) DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) +LOCKED_MODEL_2_RESPONSE = json.dumps(LOCKED_MODEL_JSON_2) NO_MODELS_LIST_RESPONSE = json.dumps({}) DEFAULT_LIST_RESPONSE = json.dumps({ 'models': [MODEL_JSON_1, MODEL_JSON_2], @@ -185,13 +265,25 @@ invalid_string_or_none_args = [0, -1, 4.2, 0x10, False, list(), dict()] -def check_error(err, err_type, msg=None): +# For validation type errors +def check_error(excinfo, err_type, msg=None): + err = excinfo.value assert isinstance(err, err_type) if msg: assert str(err) == msg -def check_firebase_error(err, code, status, msg): +# For errors that are returned in an operation +def check_operation_error(excinfo, code, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert err.code == code + assert str(err) == msg + + +# For rpc errors +def check_firebase_error(excinfo, code, status, msg): + err = excinfo.value assert isinstance(err, exceptions.FirebaseError) assert err.code == code assert err.http_response is not None @@ -199,20 +291,43 @@ def check_firebase_error(err, code, status, msg): assert str(err) == msg -def instrument_mlkit_service(app=None, status=200, payload=None): +def instrument_mlkit_service(status=200, payload=None, operations=False, app=None): if not app: app = firebase_admin.get_app() mlkit_service = mlkit._get_mlkit_service(app) recorder = [] - mlkit_service._client.session.mount( - 'https://mlkit.googleapis.com', - testutils.MockAdapter(payload, status, recorder) - ) + session_url = 'https://mlkit.googleapis.com/v1beta1/' + + if isinstance(status, list): + adapter = testutils.MockMultiRequestAdapter + else: + adapter = testutils.MockAdapter + + if operations: + mlkit_service._operation_client.session.mount( + session_url, adapter(payload, status, recorder)) + else: + mlkit_service._client.session.mount( + session_url, adapter(payload, status, recorder)) return recorder class TestModel(object): """Tests mlkit.Model class.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _op_url(project_id, model_id): + return BASE_URL + \ + 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) def test_model_success_err_state_lro(self): model = mlkit.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) @@ -297,9 +412,9 @@ def test_model_format_setters(self): (12345, TypeError) ]) def test_model_display_name_validation_errors(self, display_name, exc_type): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.Model(display_name=display_name) - check_error(err.value, exc_type) + check_error(excinfo, exc_type) @pytest.mark.parametrize('tags, exc_type, error_message', [ ('tag1', TypeError, 'Tags must be a list of strings.'), @@ -311,9 +426,9 @@ def test_model_display_name_validation_errors(self, display_name, exc_type): 'tag2'], ValueError, 'Tag format is invalid.') ]) def test_model_tags_validation_errors(self, tags, exc_type, error_message): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.Model(tags=tags) - check_error(err.value, exc_type, error_message) + check_error(excinfo, exc_type, error_message) @pytest.mark.parametrize('model_format', [ 123, @@ -323,9 +438,9 @@ def test_model_tags_validation_errors(self, tags, exc_type, error_message): True ]) def test_model_format_validation_errors(self, model_format): - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError) as excinfo: mlkit.Model(model_format=model_format) - check_error(err.value, TypeError, 'Model format must be a ModelFormat object.') + check_error(excinfo, TypeError, 'Model format must be a ModelFormat object.') @pytest.mark.parametrize('model_source', [ 123, @@ -335,9 +450,9 @@ def test_model_format_validation_errors(self, model_format): True ]) def test_model_source_validation_errors(self, model_source): - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError) as excinfo: mlkit.TFLiteFormat(model_source=model_source) - check_error(err.value, TypeError, 'Model source must be a TFLiteModelSource object.') + check_error(excinfo, TypeError, 'Model source must be a TFLiteModelSource object.') @pytest.mark.parametrize('uri, exc_type', [ (123, TypeError), @@ -351,9 +466,153 @@ def test_model_source_validation_errors(self, model_source): ValueError) ]) def test_gcs_tflite_source_validation_errors(self, uri, exc_type): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.TFLiteGCSModelSource(gcs_tflite_uri=uri) - check_error(err.value, exc_type) + check_error(excinfo, exc_type) + + def test_wait_for_unlocked_not_locked(self): + model = mlkit.Model(display_name="not_locked") + model.wait_for_unlocked() + + def test_wait_for_unlocked(self): + recorder = instrument_mlkit_service(status=200, + operations=True, + payload=OPERATION_DONE_PUBLISHED_RESPONSE) + model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_1) + model.wait_for_unlocked() + assert model == FULL_MODEL_PUBLISHED + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestModel._op_url(PROJECT_ID, MODEL_ID_1) + + def test_wait_for_unlocked_timeout(self): + recorder = instrument_mlkit_service( + status=200, operations=True, payload=OPERATION_NOT_DONE_RESPONSE) + mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 3 # longer so timeout applies immediately + model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_1) + with pytest.raises(Exception) as excinfo: + model.wait_for_unlocked(max_time_seconds=0.1) + check_error(excinfo, exceptions.DeadlineExceededError, 'Polling max time exceeded.') + assert len(recorder) == 1 + + +class TestCreateModel(object): + """Tests mlkit.create_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id): + return BASE_URL + 'projects/{0}/models'.format(project_id) + + @staticmethod + def _op_url(project_id, model_id): + return BASE_URL + \ + 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + + @staticmethod + def _get_url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + def test_immediate_done(self): + instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = mlkit.create_model(MODEL_1) + assert model == CREATED_MODEL_1 + + def test_returns_locked(self): + recorder = instrument_mlkit_service( + status=[200, 200], + payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) + model = mlkit.create_model(MODEL_1) + + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'POST' + assert recorder[0].url == TestCreateModel._url(PROJECT_ID) + assert recorder[1].method == 'GET' + assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) + + def test_operation_error(self): + instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) + with pytest.raises(Exception) as excinfo: + mlkit.create_model(MODEL_1) + # The http request succeeded, the operation returned contains a create failure + check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) + + def test_malformed_operation(self): + recorder = instrument_mlkit_service( + status=[200, 200], + payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) + model = mlkit.create_model(MODEL_1) + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'POST' + assert recorder[0].url == TestCreateModel._url(PROJECT_ID) + assert recorder[1].method == 'GET' + assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) + + def test_rpc_error_create(self): + create_recorder = instrument_mlkit_service( + status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(Exception) as excinfo: + mlkit.create_model(MODEL_1) + check_firebase_error( + excinfo, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(create_recorder) == 1 + + @pytest.mark.parametrize('model', [ + 'abc', + 4.2, + list(), + dict(), + True, + -1, + 0, + None + ]) + def test_not_model(self, model): + with pytest.raises(Exception) as excinfo: + mlkit.create_model(model) + check_error(excinfo, TypeError, 'Model must be an mlkit.Model.') + + def test_missing_display_name(self): + with pytest.raises(Exception) as excinfo: + mlkit.create_model(mlkit.Model.from_dict({})) + check_error(excinfo, ValueError, 'Model must have a display name.') + + def test_missing_op_name(self): + instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) + with pytest.raises(Exception) as excinfo: + mlkit.create_model(MODEL_1) + check_error(excinfo, TypeError) + + @pytest.mark.parametrize('op_name', [ + 'abc', + '123', + 'projects/operations/project/1234/model/abc/operation/123', + 'operations/project/model/abc/operation/123', + 'operations/project/123/model/$#@/operation/123', + 'operations/project/1234/model/abc/operation/123/extrathing', + ]) + def test_invalid_op_name(self, op_name): + payload = json.dumps({'name': op_name}) + instrument_mlkit_service(status=200, payload=payload) + with pytest.raises(Exception) as excinfo: + mlkit.create_model(MODEL_1) + check_error(excinfo, ValueError, 'Operation name format is invalid.') class TestGetModel(object): @@ -383,16 +642,16 @@ def test_get_model(self): @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args) def test_get_model_validation_errors(self, model_id, exc_type): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.get_model(model_id) - check_error(err.value, exc_type) + check_error(excinfo, exc_type) def test_get_model_error(self): recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) - with pytest.raises(exceptions.NotFoundError) as err: + with pytest.raises(exceptions.NotFoundError) as excinfo: mlkit.get_model(MODEL_ID_1) check_firebase_error( - err.value, + excinfo, ERROR_STATUS_NOT_FOUND, ERROR_CODE_NOT_FOUND, ERROR_MSG_NOT_FOUND @@ -433,16 +692,16 @@ def test_delete_model(self): @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args) def test_delete_model_validation_errors(self, model_id, exc_type): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.delete_model(model_id) - check_error(err.value, exc_type) + check_error(excinfo, exc_type) def test_delete_model_error(self): recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) - with pytest.raises(exceptions.NotFoundError) as err: + with pytest.raises(exceptions.NotFoundError) as excinfo: mlkit.delete_model(MODEL_ID_1) check_firebase_error( - err.value, + excinfo, ERROR_STATUS_NOT_FOUND, ERROR_CODE_NOT_FOUND, ERROR_MSG_NOT_FOUND @@ -514,9 +773,9 @@ def test_list_models_with_all_args(self): @pytest.mark.parametrize('list_filter', invalid_string_or_none_args) def test_list_models_list_filter_validation(self, list_filter): - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError) as excinfo: mlkit.list_models(list_filter=list_filter) - check_error(err.value, TypeError, 'List filter must be a string or None.') + check_error(excinfo, TypeError, 'List filter must be a string or None.') @pytest.mark.parametrize('page_size, exc_type, error_message', [ ('abc', TypeError, 'Page size must be a number or None.'), @@ -529,22 +788,22 @@ def test_list_models_list_filter_validation(self, list_filter): (mlkit._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) ]) def test_list_models_page_size_validation(self, page_size, exc_type, error_message): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.list_models(page_size=page_size) - check_error(err.value, exc_type, error_message) + check_error(excinfo, exc_type, error_message) @pytest.mark.parametrize('page_token', invalid_string_or_none_args) def test_list_models_page_token_validation(self, page_token): - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError) as excinfo: mlkit.list_models(page_token=page_token) - check_error(err.value, TypeError, 'Page token must be a string or None.') + check_error(excinfo, TypeError, 'Page token must be a string or None.') def test_list_models_error(self): recorder = instrument_mlkit_service(status=400, payload=ERROR_RESPONSE_BAD_REQUEST) - with pytest.raises(exceptions.InvalidArgumentError) as err: + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: mlkit.list_models() check_firebase_error( - err.value, + excinfo, ERROR_STATUS_BAD_REQUEST, ERROR_CODE_BAD_REQUEST, ERROR_MSG_BAD_REQUEST