diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index dd02ea8db..3f1a825f6 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -18,6 +18,8 @@ deleting, publishing and unpublishing Firebase ML Kit models. """ +import datetime +import numbers import re import requests import six @@ -28,6 +30,12 @@ _MLKIT_ATTRIBUTE = '_mlkit' _MAX_PAGE_SIZE = 100 +_MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+') +_RESOURCE_NAME_PATTERN = re.compile( + r'^projects/(?P[^/]+)/models/(?P[A-Za-z0-9_-]{1,60})$') def _get_mlkit_service(app): @@ -47,7 +55,7 @@ def _get_mlkit_service(app): def get_model(model_id, app=None): mlkit_service = _get_mlkit_service(app) - return Model(mlkit_service.get_model(model_id)) + return Model.from_dict(mlkit_service.get_model(model_id)) def list_models(list_filter=None, page_size=None, page_token=None, app=None): @@ -62,14 +70,39 @@ def delete_model(model_id, app=None): class Model(object): - """A Firebase ML Kit Model object.""" - def __init__(self, data): - """Created from a data dictionary.""" - self._data = data + """A Firebase ML Kit Model object. + + Args: + display_name: The display name of your model - used to identify your model in code. + tags: Optional list of strings associated with your model. Can be used in list queries. + model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details. + """ + def __init__(self, display_name=None, tags=None, model_format=None): + self._data = {} + self._model_format = None + + if display_name is not None: + self.display_name = display_name + if tags is not None: + self.tags = tags + if model_format is not None: + self.model_format = model_format + + @classmethod + def from_dict(cls, data): + data_copy = dict(data) + tflite_format = None + tflite_format_data = data_copy.pop('tfliteModel', None) + if tflite_format_data: + tflite_format = TFLiteFormat.from_dict(tflite_format_data) + model = Model(model_format=tflite_format) + model._data = data_copy # pylint: disable=protected-access + return model def __eq__(self, other): if isinstance(other, self.__class__): - return self._data == other._data # pylint: disable=protected-access + # pylint: disable=protected-access + return self._data == other._data and self._model_format == other._model_format else: return False @@ -77,14 +110,182 @@ def __ne__(self, other): return not self.__eq__(other) @property - def name(self): - return self._data['name'] + def model_id(self): + if not self._data.get('name'): + return None + _, model_id = _validate_and_parse_name(self._data.get('name')) + return model_id @property def display_name(self): - return self._data['displayName'] + return self._data.get('displayName') + + @display_name.setter + def display_name(self, display_name): + self._data['displayName'] = _validate_display_name(display_name) + return self + + @property + def create_time(self): + """Returns the creation timestamp""" + seconds = self._data.get('createTime', {}).get('seconds') + if not isinstance(seconds, numbers.Number): + return None + + return datetime.datetime.fromtimestamp(float(seconds)) + + @property + def update_time(self): + """Returns the last update timestamp""" + seconds = self._data.get('updateTime', {}).get('seconds') + if not isinstance(seconds, numbers.Number): + return None - #TODO(ifielker): define the rest of the Model properties etc + return datetime.datetime.fromtimestamp(float(seconds)) + + @property + def validation_error(self): + return self._data.get('state', {}).get('validationError', {}).get('message') + + @property + def published(self): + return bool(self._data.get('state', {}).get('published')) + + @property + def etag(self): + return self._data.get('etag') + + @property + def model_hash(self): + return self._data.get('modelHash') + + @property + def tags(self): + return self._data.get('tags') + + @tags.setter + def tags(self, tags): + self._data['tags'] = _validate_tags(tags) + return self + + @property + def locked(self): + return bool(self._data.get('activeOperations') and + len(self._data.get('activeOperations')) > 0) + + @property + def model_format(self): + return self._model_format + + @model_format.setter + def model_format(self, model_format): + if model_format is not None: + _validate_model_format(model_format) + self._model_format = model_format #Can be None + return self + + def as_dict(self): + copy = dict(self._data) + if self._model_format: + copy.update(self._model_format.as_dict()) + return copy + + +class ModelFormat(object): + """Abstract base class representing a Model Format such as TFLite.""" + def as_dict(self): + raise NotImplementedError + + +class TFLiteFormat(ModelFormat): + """Model format representing a TFLite model. + + Args: + model_source: A TFLiteModelSource sub class. Specifies the details of the model source. + """ + def __init__(self, model_source=None): + self._data = {} + self._model_source = None + + if model_source is not None: + self.model_source = model_source + + @classmethod + def from_dict(cls, data): + data_copy = dict(data) + model_source = None + gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None) + if gcs_tflite_uri: + model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) + tflite_format = TFLiteFormat(model_source=model_source) + tflite_format._data = data_copy # pylint: disable=protected-access + return tflite_format + + + def __eq__(self, other): + if isinstance(other, self.__class__): + # pylint: disable=protected-access + return self._data == other._data and self._model_source == other._model_source + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def model_source(self): + return self._model_source + + @model_source.setter + def model_source(self, model_source): + if model_source is not None: + if not isinstance(model_source, TFLiteModelSource): + raise TypeError('Model source must be a TFLiteModelSource object.') + self._model_source = model_source # Can be None + + @property + def size_bytes(self): + return self._data.get('sizeBytes') + + def as_dict(self): + copy = dict(self._data) + if self._model_source: + copy.update(self._model_source.as_dict()) + return {'tfliteModel': copy} + + +class TFLiteModelSource(object): + """Abstract base class representing a model source for TFLite format models.""" + def as_dict(self): + raise NotImplementedError + + +class TFLiteGCSModelSource(TFLiteModelSource): + """TFLite model source representing a tflite model file stored in GCS.""" + def __init__(self, gcs_tflite_uri): + self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def gcs_tflite_uri(self): + return self._gcs_tflite_uri + + @gcs_tflite_uri.setter + def gcs_tflite_uri(self, gcs_tflite_uri): + self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) + + def as_dict(self): + return {"gcsTfliteUri": self._gcs_tflite_uri} + + #TODO(ifielker): implement from_saved_model etc. class ListModelsPage(object): @@ -105,7 +306,7 @@ def __init__(self, list_models_func, list_filter, page_size, page_token): @property def models(self): """A list of Models from this page.""" - return [Model(model) for model in self._list_response.get('models', [])] + return [Model.from_dict(model) for model in self._list_response.get('models', [])] @property def list_filter(self): @@ -179,13 +380,48 @@ def __iter__(self): return self +def _validate_and_parse_name(name): + # The resource name is added automatically from API call responses. + # The only way it could be invalid is if someone tries to + # create a model from a dictionary manually and does it incorrectly. + matcher = _RESOURCE_NAME_PATTERN.match(name) + if not matcher: + raise ValueError('Model resource name format is invalid.') + return matcher.group('project_id'), matcher.group('model_id') + + def _validate_model_id(model_id): - if not isinstance(model_id, six.string_types): - raise TypeError('Model ID must be a string.') - if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id): + if not _MODEL_ID_PATTERN.match(model_id): raise ValueError('Model ID format is invalid.') +def _validate_display_name(display_name): + if not _DISPLAY_NAME_PATTERN.match(display_name): + raise ValueError('Display name format is invalid.') + return display_name + + +def _validate_tags(tags): + if not isinstance(tags, list) or not \ + all(isinstance(tag, six.string_types) for tag in tags): + raise TypeError('Tags must be a list of strings.') + if not all(_TAG_PATTERN.match(tag) for tag in tags): + raise ValueError('Tag format is invalid.') + return tags + + +def _validate_gcs_tflite_uri(uri): + # GCS Bucket naming rules are complex. The regex is not comprehensive. + # See https://cloud.google.com/storage/docs/naming for full details. + if not _GCS_TFLITE_URI_PATTERN.match(uri): + raise ValueError('GCS TFLite URI format is invalid.') + return uri + +def _validate_model_format(model_format): + if not isinstance(model_format, ModelFormat): + raise TypeError('Model format must be a ModelFormat object.') + return model_format + def _validate_list_filter(list_filter): if list_filter is not None: if not isinstance(list_filter, six.string_types): diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index e14bd4371..c20982a2b 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -14,6 +14,7 @@ """Test cases for the firebase_admin.mlkit module.""" +import datetime import json import pytest @@ -27,6 +28,25 @@ PROJECT_ID = 'myProject1' PAGE_TOKEN = 'pageToken' NEXT_PAGE_TOKEN = 'nextPageToken' +CREATE_TIME_SECONDS = 1566426374 +CREATE_TIME_JSON = { + 'seconds': CREATE_TIME_SECONDS +} +CREATE_TIME_DATETIME = datetime.datetime.fromtimestamp(CREATE_TIME_SECONDS) + +UPDATE_TIME_SECONDS = 1566426678 +UPDATE_TIME_JSON = { + 'seconds': UPDATE_TIME_SECONDS +} +UPDATE_TIME_DATETIME = datetime.datetime.fromtimestamp(UPDATE_TIME_SECONDS) +ETAG = '33a64df551425fcc55e4d42a148795d9f25f89d4' +MODEL_HASH = '987987a98b98798d098098e09809fc0893897' +TAG_1 = 'Tag1' +TAG_2 = 'Tag2' +TAG_3 = 'Tag3' +TAGS = [TAG_1, TAG_2] +TAGS_2 = [TAG_1, TAG_3] + MODEL_ID_1 = 'modelId1' MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1) DISPLAY_NAME_1 = 'displayName1' @@ -34,7 +54,7 @@ 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1 } -MODEL_1 = mlkit.Model(MODEL_JSON_1) +MODEL_1 = mlkit.Model.from_dict(MODEL_JSON_1) MODEL_ID_2 = 'modelId2' MODEL_NAME_2 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_2) @@ -43,7 +63,7 @@ 'name': MODEL_NAME_2, 'displayName': DISPLAY_NAME_2 } -MODEL_2 = mlkit.Model(MODEL_JSON_2) +MODEL_2 = mlkit.Model.from_dict(MODEL_JSON_2) MODEL_ID_3 = 'modelId3' MODEL_NAME_3 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_3) @@ -52,7 +72,69 @@ 'name': MODEL_NAME_3, 'displayName': DISPLAY_NAME_3 } -MODEL_3 = mlkit.Model(MODEL_JSON_3) +MODEL_3 = mlkit.Model.from_dict(MODEL_JSON_3) + +MODEL_STATE_PUBLISHED_JSON = { + 'published': True +} +VALIDATION_ERROR_CODE = 400 +VALIDATION_ERROR_MSG = 'No model format found for {0}.'.format(MODEL_ID_1) +MODEL_STATE_ERROR_JSON = { + 'validationError': { + 'code': VALIDATION_ERROR_CODE, + 'message': VALIDATION_ERROR_MSG, + } +} + +OPERATION_NOT_DONE_JSON_1 = { + 'name': 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1), + 'metadata': { + '@type': 'type.googleapis.com/google.firebase.ml.v1beta1.ModelOperationMetadata', + 'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1), + 'basic_operation_status': 'BASIC_OPERATION_STATUS_UPLOADING' + } +} + +GCS_TFLITE_URI = 'gs://my_bucket/mymodel.tflite' +GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI} +GCS_TFLITE_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) +TFLITE_FORMAT_JSON = { + 'gcsTfliteUri': GCS_TFLITE_URI, + 'sizeBytes': '1234567' +} +TFLITE_FORMAT = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON) + +GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite' +GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2} +GCS_TFLITE_MODEL_SOURCE_2 = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI_2) +TFLITE_FORMAT_JSON_2 = { + 'gcsTfliteUri': GCS_TFLITE_URI_2, + 'sizeBytes': '2345678' +} +TFLITE_FORMAT_2 = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) + +FULL_MODEL_ERR_STATE_LRO_JSON = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME_JSON, + 'updateTime': UPDATE_TIME_JSON, + 'state': MODEL_STATE_ERROR_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1], +} +FULL_MODEL_PUBLISHED_JSON = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME_JSON, + 'updateTime': UPDATE_TIME_JSON, + 'state': MODEL_STATE_PUBLISHED_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, + 'tfliteModel': TFLITE_FORMAT_JSON +} EMPTY_RESPONSE = json.dumps({}) DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) @@ -93,29 +175,20 @@ ERROR_RESPONSE_BAD_REQUEST = json.dumps(ERROR_JSON_BAD_REQUEST) invalid_model_id_args = [ - ('', ValueError, 'Model ID format is invalid.'), - ('&_*#@:/?', ValueError, 'Model ID format is invalid.'), - (None, TypeError, 'Model ID must be a string.'), - (12345, TypeError, 'Model ID must be a string.'), + ('', ValueError), + ('&_*#@:/?', ValueError), + (None, TypeError), + (12345, TypeError), ] PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ '1 and {0}'.format(mlkit._MAX_PAGE_SIZE) -invalid_page_size_args = [ - ('abc', TypeError, 'Page size must be a number or None.'), - (4.2, TypeError, 'Page size must be a number or None.'), - (list(), TypeError, 'Page size must be a number or None.'), - (dict(), TypeError, 'Page size must be a number or None.'), - (True, TypeError, 'Page size must be a number or None.'), - (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), - (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), - (mlkit._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) -] invalid_string_or_none_args = [0, -1, 4.2, 0x10, False, list(), dict()] -def check_error(err, err_type, msg): +def check_error(err, err_type, msg=None): assert isinstance(err, err_type) - assert str(err) == msg + if msg: + assert str(err) == msg def check_firebase_error(err, code, status, msg): @@ -138,6 +211,151 @@ def instrument_mlkit_service(app=None, status=200, payload=None): return recorder +class TestModel(object): + """Tests mlkit.Model class.""" + + def test_model_success_err_state_lro(self): + model = mlkit.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) + assert model.model_id == MODEL_ID_1 + assert model.display_name == DISPLAY_NAME_1 + assert model.create_time == CREATE_TIME_DATETIME + assert model.update_time == UPDATE_TIME_DATETIME + assert model.validation_error == VALIDATION_ERROR_MSG + assert model.published is False + assert model.etag == ETAG + assert model.model_hash == MODEL_HASH + assert model.tags == TAGS + assert model.locked is True + assert model.model_format is None + assert model.as_dict() == FULL_MODEL_ERR_STATE_LRO_JSON + + def test_model_success_published(self): + model = mlkit.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) + assert model.model_id == MODEL_ID_1 + assert model.display_name == DISPLAY_NAME_1 + assert model.create_time == CREATE_TIME_DATETIME + assert model.update_time == UPDATE_TIME_DATETIME + assert model.validation_error is None + assert model.published is True + assert model.etag == ETAG + assert model.model_hash == MODEL_HASH + assert model.tags == TAGS + assert model.locked is False + assert model.model_format == TFLITE_FORMAT + assert model.as_dict() == FULL_MODEL_PUBLISHED_JSON + + def test_model_keyword_based_creation_and_setters(self): + model = mlkit.Model(display_name=DISPLAY_NAME_1, tags=TAGS, model_format=TFLITE_FORMAT) + assert model.display_name == DISPLAY_NAME_1 + assert model.tags == TAGS + assert model.model_format == TFLITE_FORMAT + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_1, + 'tags': TAGS, + 'tfliteModel': TFLITE_FORMAT_JSON + } + + model.display_name = DISPLAY_NAME_2 + model.tags = TAGS_2 + model.model_format = TFLITE_FORMAT_2 + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_2, + 'tags': TAGS_2, + 'tfliteModel': TFLITE_FORMAT_JSON_2 + } + + def test_model_format_source_creation(self): + model_source = mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) + model_format = mlkit.TFLiteFormat(model_source=model_source) + model = mlkit.Model(display_name=DISPLAY_NAME_1, model_format=model_format) + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_1, + 'tfliteModel': { + 'gcsTfliteUri': GCS_TFLITE_URI + } + } + + def test_model_source_setters(self): + model_source = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) + model_source.gcs_tflite_uri = GCS_TFLITE_URI_2 + assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 + assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2 + + def test_model_format_setters(self): + model_format = mlkit.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) + model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 + assert model_format.model_source == GCS_TFLITE_MODEL_SOURCE_2 + assert model_format.as_dict() == { + 'tfliteModel': { + 'gcsTfliteUri': GCS_TFLITE_URI_2 + } + } + + @pytest.mark.parametrize('display_name, exc_type', [ + ('', ValueError), + ('&_*#@:/?', ValueError), + (12345, TypeError) + ]) + def test_model_display_name_validation_errors(self, display_name, exc_type): + with pytest.raises(exc_type) as err: + mlkit.Model(display_name=display_name) + check_error(err.value, exc_type) + + @pytest.mark.parametrize('tags, exc_type, error_message', [ + ('tag1', TypeError, 'Tags must be a list of strings.'), + (123, TypeError, 'Tags must be a list of strings.'), + (['tag1', 123, 'tag2'], TypeError, 'Tags must be a list of strings.'), + (['tag1', '@#$%^&'], ValueError, 'Tag format is invalid.'), + (['', 'tag2'], ValueError, 'Tag format is invalid.'), + (['sixty-one_characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', + 'tag2'], ValueError, 'Tag format is invalid.') + ]) + def test_model_tags_validation_errors(self, tags, exc_type, error_message): + with pytest.raises(exc_type) as err: + mlkit.Model(tags=tags) + check_error(err.value, exc_type, error_message) + + @pytest.mark.parametrize('model_format', [ + 123, + "abc", + {}, + [], + True + ]) + def test_model_format_validation_errors(self, model_format): + with pytest.raises(TypeError) as err: + mlkit.Model(model_format=model_format) + check_error(err.value, TypeError, 'Model format must be a ModelFormat object.') + + @pytest.mark.parametrize('model_source', [ + 123, + "abc", + {}, + [], + True + ]) + def test_model_source_validation_errors(self, model_source): + with pytest.raises(TypeError) as err: + mlkit.TFLiteFormat(model_source=model_source) + check_error(err.value, TypeError, 'Model source must be a TFLiteModelSource object.') + + @pytest.mark.parametrize('uri, exc_type', [ + (123, TypeError), + ('abc', ValueError), + ('gs://NO_CAPITALS', ValueError), + ('gs://abc/', ValueError), + ('gs://aa/model.tflite', ValueError), + ('gs://@#$%/model.tflite', ValueError), + ('gs://invalid space/model.tflite', ValueError), + ('gs://sixty-four-characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx/model.tflite', + ValueError) + ]) + def test_gcs_tflite_source_validation_errors(self, uri, exc_type): + with pytest.raises(exc_type) as err: + mlkit.TFLiteGCSModelSource(gcs_tflite_uri=uri) + check_error(err.value, exc_type) + + class TestGetModel(object): """Tests mlkit.get_model.""" @classmethod @@ -160,14 +378,14 @@ def test_get_model(self): assert recorder[0].method == 'GET' assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) assert model == MODEL_1 - assert model.name == MODEL_NAME_1 + assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 - @pytest.mark.parametrize('model_id, exc_type, error_message', invalid_model_id_args) - def test_get_model_validation_errors(self, model_id, exc_type, error_message): + @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args) + def test_get_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as err: mlkit.get_model(model_id) - check_error(err.value, exc_type, error_message) + check_error(err.value, exc_type) def test_get_model_error(self): recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) @@ -190,6 +408,7 @@ def evaluate(): mlkit.get_model(MODEL_ID_1, app) testutils.run_without_project_id(evaluate) + class TestDeleteModel(object): """Tests mlkit.delete_model.""" @classmethod @@ -212,11 +431,11 @@ def test_delete_model(self): assert recorder[0].method == 'DELETE' assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) - @pytest.mark.parametrize('model_id, exc_type, error_message', invalid_model_id_args) - def test_delete_model_validation_errors(self, model_id, exc_type, error_message): + @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args) + def test_delete_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as err: mlkit.delete_model(model_id) - check_error(err.value, exc_type, error_message) + check_error(err.value, exc_type) def test_delete_model_error(self): recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) @@ -299,7 +518,16 @@ def test_list_models_list_filter_validation(self, list_filter): mlkit.list_models(list_filter=list_filter) check_error(err.value, TypeError, 'List filter must be a string or None.') - @pytest.mark.parametrize('page_size, exc_type, error_message', invalid_page_size_args) + @pytest.mark.parametrize('page_size, exc_type, error_message', [ + ('abc', TypeError, 'Page size must be a number or None.'), + (4.2, TypeError, 'Page size must be a number or None.'), + (list(), TypeError, 'Page size must be a number or None.'), + (dict(), TypeError, 'Page size must be a number or None.'), + (True, TypeError, 'Page size must be a number or None.'), + (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), + (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), + (mlkit._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) + ]) def test_list_models_page_size_validation(self, page_size, exc_type, error_message): with pytest.raises(exc_type) as err: mlkit.list_models(page_size=page_size)