From 573e7cbb1af21e257c865643919a7ed06fa3a75c Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 27 Aug 2019 17:52:27 -0400 Subject: [PATCH 01/11] Implementation of Model, ModelFormat, ModelSource and subclasses --- firebase_admin/mlkit.py | 238 +++++++++++++++++++++++++++++++++++++++- tests/test_mlkit.py | 209 ++++++++++++++++++++++++++++++++++- 2 files changed, 440 insertions(+), 7 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index dd02ea8db..74ba8a50d 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -18,6 +18,8 @@ deleting, publishing and unpublishing Firebase ML Kit models. """ +import datetime +import numbers import re import requests import six @@ -63,9 +65,25 @@ def delete_model(model_id, app=None): class Model(object): """A Firebase ML Kit Model object.""" - def __init__(self, data): + def __init__(self, data=None, display_name=None, tags=None, model_format=None): """Created from a data dictionary.""" - self._data = data + if data is not None and isinstance(data, dict): + self._data = data + else: + self._data = {} + if display_name is not None: + _validate_display_name(display_name) + self._data['displayName'] = display_name + if tags is not None: + _validate_tags(tags) + self._data['tags'] = tags + if model_format is not None: + _validate_model_format(model_format) + if isinstance(model_format, TFLiteFormat): + self._data['tfliteModel'] = model_format.get_json() + else: + raise TypeError('Unsupported model format type.') + def __eq__(self, other): if isinstance(other, self.__class__): @@ -77,15 +95,181 @@ def __ne__(self, other): return not self.__eq__(other) @property - def name(self): - return self._data['name'] + def model_id(self): + if not self._data.get('name'): + return None + _, model_id = _validate_and_parse_name(self._data.get('name')) + return model_id @property def display_name(self): - return self._data['displayName'] + return self._data.get('displayName') + + @display_name.setter + def display_name(self, display_name): + _validate_display_name(display_name) + self._data['displayName'] = display_name + return self + + @property + def create_time(self): + if self._data.get('createTime') and \ + self._data.get('createTime').get('seconds') and \ + isinstance(self._data.get('createTime').get('seconds'), numbers.Number): + return datetime.datetime.fromtimestamp( + float(self._data.get('createTime').get('seconds'))) + return None + + @property + def update_time(self): + if self._data.get('updateTime') and \ + self._data.get('updateTime').get('seconds') and \ + isinstance(self._data.get('updateTime').get('seconds'), numbers.Number): + return datetime.datetime.fromtimestamp( + float(self._data.get('updateTime').get('seconds'))) + return None + + @property + def validation_error(self): + return self._data.get('state') and \ + self._data.get('state').get('validationError') and \ + self._data.get('state').get('validationError').get('message') + + @property + def published(self): + return bool(self._data.get('state') and + self._data.get('state').get('published')) + + @property + def etag(self): + return self._data.get('etag') + + @property + def model_hash(self): + return self._data.get('modelHash') + + @property + def tags(self): + return self._data.get('tags') + + @tags.setter + def tags(self, tags): + _validate_tags(tags) + self._data['tags'] = tags + return self + + @property + def locked(self): + return bool(self._data.get('activeOperations') and + len(self._data.get('activeOperations')) > 0) + + @property + def model_format(self): + if self._data.get('tfliteModel'): + return TFLiteFormat(self._data.get('tfliteModel')) + return None + + @model_format.setter + def model_format(self, model_format): + if not isinstance(model_format, TFLiteFormat): + raise TypeError('Unsupported model format type.') + self._data['tfliteModel'] = model_format.get_json() + return self + + def get_json(self): + return self._data + + +class ModelFormat(object): + """Abstract base class representing a Model Format such as TFLite.""" + def get_json(self): + raise NotImplementedError + + +class TFLiteFormat(ModelFormat): + """Model format representing a TFLite model.""" + def __init__(self, data=None, model_source=None): + if (data is not None) and isinstance(data, dict): + self._data = data + else: + self._data = {} + if model_source is not None: + # Check for correct base type + if not isinstance(model_source, TFLiteModelSource): + raise TypeError('Model source must be a ModelSource object.') + # Set based on specific sub type + if isinstance(model_source, TFLiteGCSModelSource): + self._data['gcsTfliteUri'] = model_source.get_json() + else: + raise TypeError('Unsupported model source type.') + + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self._data == other._data # pylint: disable=protected-access + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def model_source(self): + if self._data.get('gcsTfliteUri'): + return TFLiteGCSModelSource(self._data.get('gcsTfliteUri')) + return None + + @model_source.setter + def model_source(self, model_source): + if model_source is not None: + if isinstance(model_source, TFLiteGCSModelSource): + self._data['gcsTfliteUri'] = model_source.get_json() + else: + raise TypeError('Unsupported model source type.') + + + @property + def size_bytes(self): + return self._data.get('sizeBytes') + + def get_json(self): + return self._data + + +class TFLiteModelSource(object): + """Abstract base class representing a model source for TFLite format models.""" + def get_json(self): + raise NotImplementedError + + +class TFLiteGCSModelSource(TFLiteModelSource): + """TFLite model source representing a tflite model file stored in GCS.""" + def __init__(self, gcs_tflite_uri): + _validate_gcs_tflite_uri(gcs_tflite_uri) + self._gcs_tflite_uri = gcs_tflite_uri - #TODO(ifielker): define the rest of the Model properties etc + def __eq__(self, other): + if isinstance(other, self.__class__): + return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def gcs_tflite_uri(self): + return self._gcs_tflite_uri + + @gcs_tflite_uri.setter + def gcs_tflite_uri(self, gcs_tflite_uri): + _validate_gcs_tflite_uri(gcs_tflite_uri) + self._gcs_tflite_uri = gcs_tflite_uri + def get_json(self): + return self._gcs_tflite_uri + + #TODO(ifielker): implement from_saved_model etc. class ListModelsPage(object): """Represents a page of models in a firebase project. @@ -179,6 +363,20 @@ def __iter__(self): return self +def _validate_and_parse_name(name): + # The resource name is added automatically from API call responses. + # The only way it could be invalid is if someone tries to + # create a model from a dictionary manually and does it incorrectly. + if not isinstance(name, six.string_types): + raise TypeError('Model resource name must be a string.') + matcher = re.match( + r'^projects/(?P[^/]+)/models/(?P[A-Za-z0-9_-]{1,60})$', + name) + if not matcher: + raise ValueError('Model resource name format is invalid.') + return matcher.group('project_id'), matcher.group('model_id') + + def _validate_model_id(model_id): if not isinstance(model_id, six.string_types): raise TypeError('Model ID must be a string.') @@ -186,6 +384,34 @@ def _validate_model_id(model_id): raise ValueError('Model ID format is invalid.') +def _validate_display_name(display_name): + if not isinstance(display_name, six.string_types): + raise TypeError('Display name must be a string.') + if not re.match(r'^[A-Za-z0-9_-]{1,60}$', display_name): + raise ValueError('Display name format is invalid.') + + +def _validate_tags(tags): + if not isinstance(tags, list) or not \ + all(isinstance(tag, six.string_types) for tag in tags): + raise TypeError('Tags must be a list of strings.') + if not all(re.match(r'^[A-Za-z0-9_-]{1,60}$', tag) for tag in tags): + raise ValueError('Tag format is invalid.') + + +def _validate_gcs_tflite_uri(uri): + if not isinstance(uri, six.string_types): + raise TypeError('Gcs TFLite URI must be a string.') + # GCS Bucket naming rules are complex. The regex is not comprehensive. + # See https://cloud.google.com/storage/docs/naming for full details. + if not re.match(r'^gs://[a-z0-9_.-]{3,63}/.+', uri): + raise ValueError('GCS TFLite URI format is invalid.') + +def _validate_model_format(model_format): + if model_format is not None: + if not isinstance(model_format, ModelFormat): + raise TypeError('Model format must be a ModelFormat object.') + def _validate_list_filter(list_filter): if list_filter is not None: if not isinstance(list_filter, six.string_types): diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index e14bd4371..f7e4a4cf3 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -14,6 +14,7 @@ """Test cases for the firebase_admin.mlkit module.""" +import datetime import json import pytest @@ -27,6 +28,25 @@ PROJECT_ID = 'myProject1' PAGE_TOKEN = 'pageToken' NEXT_PAGE_TOKEN = 'nextPageToken' +CREATE_TIME_SECONDS = 1566426374 +CREATE_TIME_JSON = { + 'seconds': CREATE_TIME_SECONDS +} +CREATE_TIME_DATETIME = datetime.datetime.fromtimestamp(CREATE_TIME_SECONDS) + +UPDATE_TIME_SECONDS = 1566426678 +UPDATE_TIME_JSON = { + 'seconds': UPDATE_TIME_SECONDS +} +UPDATE_TIME_DATETIME = datetime.datetime.fromtimestamp(UPDATE_TIME_SECONDS) +ETAG = '33a64df551425fcc55e4d42a148795d9f25f89d4' +MODEL_HASH = '987987a98b98798d098098e09809fc0893897' +TAG_1 = 'Tag1' +TAG_2 = 'Tag2' +TAG_3 = 'Tag3' +TAGS = [TAG_1, TAG_2] +TAGS_2 = [TAG_1, TAG_3] + MODEL_ID_1 = 'modelId1' MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1) DISPLAY_NAME_1 = 'displayName1' @@ -54,6 +74,69 @@ } MODEL_3 = mlkit.Model(MODEL_JSON_3) +MODEL_STATE_PUBLISHED_JSON = { + 'published': True +} +VALIDATION_ERROR_CODE = 400 +VALIDATION_ERROR_MSG = 'No model format found for {0}.'.format(MODEL_ID_1) +MODEL_STATE_ERROR_JSON = { + 'validationError': { + 'code': VALIDATION_ERROR_CODE, + 'message': VALIDATION_ERROR_MSG, + } +} + +OPERATION_NOT_DONE_JSON_1 = { + 'name': 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1), + 'metadata': { + '@type': 'type.googleapis.com/google.firebase.ml.v1beta1.ModelOperationMetadata', + 'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1), + 'basic_operation_status': 'BASIC_OPERATION_STATUS_UPLOADING' + } +} + +GCS_TFLITE_URI = 'gs://my_bucket/mymodel.tflite' +GCS_TFLITE_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) +TFLITE_FORMAT_JSON = { + 'gcsTfliteUri': GCS_TFLITE_URI, + 'sizeBytes': '1234567' +} +TFLITE_FORMAT = mlkit.TFLiteFormat(TFLITE_FORMAT_JSON) + +GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite' +GCS_TFLITE_MODEL_SOURCE_2 = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI_2) +TFLITE_FORMAT_JSON_2 = { + 'gcsTfliteUri': GCS_TFLITE_URI_2, + 'sizeBytes': '2345678' +} +TFLITE_FORMAT_2 = mlkit.TFLiteFormat(TFLITE_FORMAT_JSON_2) + +FULL_MODEL_ERR_STATE_LRO_JSON = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME_JSON, + 'updateTime': UPDATE_TIME_JSON, + 'state': MODEL_STATE_ERROR_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1], +} +FULL_MODEL_ERR_STATE_LRO = mlkit.Model(FULL_MODEL_ERR_STATE_LRO_JSON) + +FULL_MODEL_PUBLISHED_JSON = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME_JSON, + 'updateTime': UPDATE_TIME_JSON, + 'state': MODEL_STATE_PUBLISHED_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, + 'tfliteModel': TFLITE_FORMAT_JSON +} +FULL_MODEL_PUBLISHED = mlkit.Model(FULL_MODEL_PUBLISHED_JSON) + EMPTY_RESPONSE = json.dumps({}) DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) NO_MODELS_LIST_RESPONSE = json.dumps({}) @@ -92,6 +175,29 @@ } ERROR_RESPONSE_BAD_REQUEST = json.dumps(ERROR_JSON_BAD_REQUEST) +invalid_display_name_args = [ + ('', ValueError, 'Display name format is invalid.'), + ('&_*#@:/?', ValueError, 'Display name format is invalid.'), + (12345, TypeError, 'Display name must be a string.') +] +invalid_tags_args = [ + ('tag1', TypeError, 'Tags must be a list of strings.'), + (123, TypeError, 'Tags must be a list of strings.'), + (['tag1', 123, 'tag2'], TypeError, 'Tags must be a list of strings.'), + (['tag1', '@#$%^&'], ValueError, 'Tag format is invalid.'), + (['', 'tag2'], ValueError, 'Tag format is invalid.'), + (['sixty-one_characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', + 'tag2'], ValueError, 'Tag format is invalid.') +] +invalid_model_format_args = [ + (123, 'Model format must be a ModelFormat object.'), + (mlkit.ModelFormat(), 'Unsupported model format type.') +] +invalid_model_source_args = [ + (123, 'Model source must be a ModelSource object.'), + (mlkit.TFLiteModelSource(), 'Unsupported model source type.'), + +] invalid_model_id_args = [ ('', ValueError, 'Model ID format is invalid.'), ('&_*#@:/?', ValueError, 'Model ID format is invalid.'), @@ -138,6 +244,106 @@ def instrument_mlkit_service(app=None, status=200, payload=None): return recorder +class TestModel(object): + """Tests mlkit.Model class.""" + + def test_model_success_err_state_lro(self): + assert FULL_MODEL_ERR_STATE_LRO.model_id == MODEL_ID_1 + assert FULL_MODEL_ERR_STATE_LRO.display_name == DISPLAY_NAME_1 + assert FULL_MODEL_ERR_STATE_LRO.create_time == CREATE_TIME_DATETIME + assert FULL_MODEL_ERR_STATE_LRO.update_time == UPDATE_TIME_DATETIME + assert FULL_MODEL_ERR_STATE_LRO.validation_error == VALIDATION_ERROR_MSG + assert FULL_MODEL_ERR_STATE_LRO.published is False + assert FULL_MODEL_ERR_STATE_LRO.etag == ETAG + assert FULL_MODEL_ERR_STATE_LRO.model_hash == MODEL_HASH + assert FULL_MODEL_ERR_STATE_LRO.tags == TAGS + assert FULL_MODEL_ERR_STATE_LRO.locked is True + assert FULL_MODEL_ERR_STATE_LRO.model_format is None + assert FULL_MODEL_ERR_STATE_LRO.get_json() == FULL_MODEL_ERR_STATE_LRO_JSON + + def test_model_success_published(self): + assert FULL_MODEL_PUBLISHED.model_id == MODEL_ID_1 + assert FULL_MODEL_PUBLISHED.display_name == DISPLAY_NAME_1 + assert FULL_MODEL_PUBLISHED.create_time == CREATE_TIME_DATETIME + assert FULL_MODEL_PUBLISHED.update_time == UPDATE_TIME_DATETIME + assert FULL_MODEL_PUBLISHED.validation_error is None + assert FULL_MODEL_PUBLISHED.published is True + assert FULL_MODEL_PUBLISHED.etag == ETAG + assert FULL_MODEL_PUBLISHED.model_hash == MODEL_HASH + assert FULL_MODEL_PUBLISHED.tags == TAGS + assert FULL_MODEL_PUBLISHED.locked is False + assert FULL_MODEL_PUBLISHED.model_format == TFLITE_FORMAT + assert FULL_MODEL_PUBLISHED.get_json() == FULL_MODEL_PUBLISHED_JSON + + def test_model_keyword_based_creation_and_setters(self): + model = mlkit.Model(display_name=DISPLAY_NAME_1, tags=TAGS, model_format=TFLITE_FORMAT) + assert model.display_name == DISPLAY_NAME_1 + assert model.tags == TAGS + assert model.model_format == TFLITE_FORMAT + assert model.get_json() == { + 'displayName': DISPLAY_NAME_1, + 'tags': TAGS, + 'tfliteModel': TFLITE_FORMAT_JSON + } + + model.display_name = DISPLAY_NAME_2 + model.tags = TAGS_2 + model.model_format = TFLITE_FORMAT_2 + assert model.get_json() == { + 'displayName': DISPLAY_NAME_2, + 'tags': TAGS_2, + 'tfliteModel': TFLITE_FORMAT_JSON_2 + } + + def test_model_format_source_creation(self): + model_source = mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) + model_format = mlkit.TFLiteFormat(model_source=model_source) + model = mlkit.Model(display_name=DISPLAY_NAME_1, model_format=model_format) + assert model.get_json() == { + 'displayName': DISPLAY_NAME_1, + 'tfliteModel': { + 'gcsTfliteUri': GCS_TFLITE_URI + } + } + + def test_model_source_setters(self): + model_source = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) + model_source.gcs_tflite_uri = GCS_TFLITE_URI_2 + assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 + assert model_source.get_json() == GCS_TFLITE_URI_2 + + def test_model_format_setters(self): + model_format = mlkit.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) + model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 + assert model_format.model_source == GCS_TFLITE_MODEL_SOURCE_2 + assert model_format.get_json() == { + 'gcsTfliteUri': GCS_TFLITE_URI_2 + } + + @pytest.mark.parametrize('display_name, exc_type, error_message', invalid_display_name_args) + def test_model_display_name_validation_errors(self, display_name, exc_type, error_message): + with pytest.raises(exc_type) as err: + mlkit.Model(display_name=display_name) + check_error(err.value, exc_type, error_message) + + @pytest.mark.parametrize('tags, exc_type, error_message', invalid_tags_args) + def test_model_tags_validation_errors(self, tags, exc_type, error_message): + with pytest.raises(exc_type) as err: + mlkit.Model(tags=tags) + check_error(err.value, exc_type, error_message) + + @pytest.mark.parametrize('model_format, error_message', invalid_model_format_args) + def test_model_format_validation_errors(self, model_format, error_message): + with pytest.raises(TypeError) as err: + mlkit.Model(model_format=model_format) + check_error(err.value, TypeError, error_message) + + @pytest.mark.parametrize('model_source, error_message', invalid_model_source_args) + def test_model_source_validation_errors(self, model_source, error_message): + with pytest.raises(TypeError) as err: + mlkit.TFLiteFormat(model_source=model_source) + check_error(err.value, TypeError, error_message) + class TestGetModel(object): """Tests mlkit.get_model.""" @classmethod @@ -160,7 +366,7 @@ def test_get_model(self): assert recorder[0].method == 'GET' assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) assert model == MODEL_1 - assert model.name == MODEL_NAME_1 + assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 @pytest.mark.parametrize('model_id, exc_type, error_message', invalid_model_id_args) @@ -190,6 +396,7 @@ def evaluate(): mlkit.get_model(MODEL_ID_1, app) testutils.run_without_project_id(evaluate) + class TestDeleteModel(object): """Tests mlkit.delete_model.""" @classmethod From e67bceb6d042c9af1ea644c624fa105542b51fbe Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 27 Aug 2019 20:06:58 -0400 Subject: [PATCH 02/11] review fixes --- firebase_admin/mlkit.py | 82 ++++++++++++++++++++--------------------- tests/test_mlkit.py | 81 ++++++++++++++++++++++++---------------- 2 files changed, 89 insertions(+), 74 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 74ba8a50d..2fe827679 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -49,7 +49,7 @@ def _get_mlkit_service(app): def get_model(model_id, app=None): mlkit_service = _get_mlkit_service(app) - return Model(mlkit_service.get_model(model_id)) + return Model(**mlkit_service.get_model(model_id)) def list_models(list_filter=None, page_size=None, page_token=None, app=None): @@ -65,18 +65,12 @@ def delete_model(model_id, app=None): class Model(object): """A Firebase ML Kit Model object.""" - def __init__(self, data=None, display_name=None, tags=None, model_format=None): - """Created from a data dictionary.""" - if data is not None and isinstance(data, dict): - self._data = data - else: - self._data = {} + def __init__(self, display_name=None, tags=None, model_format=None, **kwargs): + self._data = kwargs if display_name is not None: - _validate_display_name(display_name) - self._data['displayName'] = display_name + self._data['displayName'] = _validate_display_name(display_name) if tags is not None: - _validate_tags(tags) - self._data['tags'] = tags + self._data['tags'] = _validate_tags(tags) if model_format is not None: _validate_model_format(model_format) if isinstance(model_format, TFLiteFormat): @@ -107,27 +101,36 @@ def display_name(self): @display_name.setter def display_name(self, display_name): - _validate_display_name(display_name) - self._data['displayName'] = display_name + self._data['displayName'] = _validate_display_name(display_name) return self @property def create_time(self): - if self._data.get('createTime') and \ - self._data.get('createTime').get('seconds') and \ - isinstance(self._data.get('createTime').get('seconds'), numbers.Number): - return datetime.datetime.fromtimestamp( - float(self._data.get('createTime').get('seconds'))) - return None + create_time = self._data.get('createTime') + if not create_time: + return None + + seconds = create_time.get('seconds') + if not seconds: + return None + if not isinstance(seconds, numbers.Number): + return None + + return datetime.datetime.fromtimestamp(float(seconds)) @property def update_time(self): - if self._data.get('updateTime') and \ - self._data.get('updateTime').get('seconds') and \ - isinstance(self._data.get('updateTime').get('seconds'), numbers.Number): - return datetime.datetime.fromtimestamp( - float(self._data.get('updateTime').get('seconds'))) - return None + update_time = self._data.get('updateTime') + if not update_time: + return None + + seconds = update_time.get('seconds') + if not seconds: + return None + if not isinstance(seconds, numbers.Number): + return None + + return datetime.datetime.fromtimestamp(float(seconds)) @property def validation_error(self): @@ -154,8 +157,7 @@ def tags(self): @tags.setter def tags(self, tags): - _validate_tags(tags) - self._data['tags'] = tags + self._data['tags'] = _validate_tags(tags) return self @property @@ -166,7 +168,7 @@ def locked(self): @property def model_format(self): if self._data.get('tfliteModel'): - return TFLiteFormat(self._data.get('tfliteModel')) + return TFLiteFormat(**self._data.get('tfliteModel')) return None @model_format.setter @@ -188,11 +190,8 @@ def get_json(self): class TFLiteFormat(ModelFormat): """Model format representing a TFLite model.""" - def __init__(self, data=None, model_source=None): - if (data is not None) and isinstance(data, dict): - self._data = data - else: - self._data = {} + def __init__(self, model_source=None, **kwargs): + self._data = kwargs if model_source is not None: # Check for correct base type if not isinstance(model_source, TFLiteModelSource): @@ -203,7 +202,6 @@ def __init__(self, data=None, model_source=None): else: raise TypeError('Unsupported model source type.') - def __eq__(self, other): if isinstance(other, self.__class__): return self._data == other._data # pylint: disable=protected-access @@ -227,7 +225,6 @@ def model_source(self, model_source): else: raise TypeError('Unsupported model source type.') - @property def size_bytes(self): return self._data.get('sizeBytes') @@ -245,8 +242,7 @@ def get_json(self): class TFLiteGCSModelSource(TFLiteModelSource): """TFLite model source representing a tflite model file stored in GCS.""" def __init__(self, gcs_tflite_uri): - _validate_gcs_tflite_uri(gcs_tflite_uri) - self._gcs_tflite_uri = gcs_tflite_uri + self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) def __eq__(self, other): if isinstance(other, self.__class__): @@ -263,14 +259,14 @@ def gcs_tflite_uri(self): @gcs_tflite_uri.setter def gcs_tflite_uri(self, gcs_tflite_uri): - _validate_gcs_tflite_uri(gcs_tflite_uri) - self._gcs_tflite_uri = gcs_tflite_uri + self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) def get_json(self): return self._gcs_tflite_uri #TODO(ifielker): implement from_saved_model etc. + class ListModelsPage(object): """Represents a page of models in a firebase project. @@ -289,7 +285,7 @@ def __init__(self, list_models_func, list_filter, page_size, page_token): @property def models(self): """A list of Models from this page.""" - return [Model(model) for model in self._list_response.get('models', [])] + return [Model(**model) for model in self._list_response.get('models', [])] @property def list_filter(self): @@ -389,6 +385,7 @@ def _validate_display_name(display_name): raise TypeError('Display name must be a string.') if not re.match(r'^[A-Za-z0-9_-]{1,60}$', display_name): raise ValueError('Display name format is invalid.') + return display_name def _validate_tags(tags): @@ -397,20 +394,21 @@ def _validate_tags(tags): raise TypeError('Tags must be a list of strings.') if not all(re.match(r'^[A-Za-z0-9_-]{1,60}$', tag) for tag in tags): raise ValueError('Tag format is invalid.') + return tags def _validate_gcs_tflite_uri(uri): - if not isinstance(uri, six.string_types): - raise TypeError('Gcs TFLite URI must be a string.') # GCS Bucket naming rules are complex. The regex is not comprehensive. # See https://cloud.google.com/storage/docs/naming for full details. if not re.match(r'^gs://[a-z0-9_.-]{3,63}/.+', uri): raise ValueError('GCS TFLite URI format is invalid.') + return uri def _validate_model_format(model_format): if model_format is not None: if not isinstance(model_format, ModelFormat): raise TypeError('Model format must be a ModelFormat object.') + return model_format def _validate_list_filter(list_filter): if list_filter is not None: diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index f7e4a4cf3..2eaf88f83 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -54,7 +54,7 @@ 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1 } -MODEL_1 = mlkit.Model(MODEL_JSON_1) +MODEL_1 = mlkit.Model(**MODEL_JSON_1) MODEL_ID_2 = 'modelId2' MODEL_NAME_2 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_2) @@ -63,7 +63,7 @@ 'name': MODEL_NAME_2, 'displayName': DISPLAY_NAME_2 } -MODEL_2 = mlkit.Model(MODEL_JSON_2) +MODEL_2 = mlkit.Model(**MODEL_JSON_2) MODEL_ID_3 = 'modelId3' MODEL_NAME_3 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_3) @@ -72,7 +72,7 @@ 'name': MODEL_NAME_3, 'displayName': DISPLAY_NAME_3 } -MODEL_3 = mlkit.Model(MODEL_JSON_3) +MODEL_3 = mlkit.Model(**MODEL_JSON_3) MODEL_STATE_PUBLISHED_JSON = { 'published': True @@ -101,7 +101,7 @@ 'gcsTfliteUri': GCS_TFLITE_URI, 'sizeBytes': '1234567' } -TFLITE_FORMAT = mlkit.TFLiteFormat(TFLITE_FORMAT_JSON) +TFLITE_FORMAT = mlkit.TFLiteFormat(**TFLITE_FORMAT_JSON) GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite' GCS_TFLITE_MODEL_SOURCE_2 = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI_2) @@ -109,7 +109,7 @@ 'gcsTfliteUri': GCS_TFLITE_URI_2, 'sizeBytes': '2345678' } -TFLITE_FORMAT_2 = mlkit.TFLiteFormat(TFLITE_FORMAT_JSON_2) +TFLITE_FORMAT_2 = mlkit.TFLiteFormat(**TFLITE_FORMAT_JSON_2) FULL_MODEL_ERR_STATE_LRO_JSON = { 'name': MODEL_NAME_1, @@ -122,8 +122,6 @@ 'tags': TAGS, 'activeOperations': [OPERATION_NOT_DONE_JSON_1], } -FULL_MODEL_ERR_STATE_LRO = mlkit.Model(FULL_MODEL_ERR_STATE_LRO_JSON) - FULL_MODEL_PUBLISHED_JSON = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, @@ -135,7 +133,6 @@ 'tags': TAGS, 'tfliteModel': TFLITE_FORMAT_JSON } -FULL_MODEL_PUBLISHED = mlkit.Model(FULL_MODEL_PUBLISHED_JSON) EMPTY_RESPONSE = json.dumps({}) DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) @@ -198,6 +195,18 @@ (mlkit.TFLiteModelSource(), 'Unsupported model source type.'), ] +GCS_TFLITE_VALUE_ERR_MSG = 'GCS TFLite URI format is invalid.' +invalid_gcs_tflite_uri_args = [ + (123, TypeError, 'expected string or buffer'), + ('abc', ValueError, GCS_TFLITE_VALUE_ERR_MSG), + ('gs://NO_CAPITALS', ValueError, GCS_TFLITE_VALUE_ERR_MSG), + ('gs://abc/', ValueError, GCS_TFLITE_VALUE_ERR_MSG), + ('gs://aa/model.tflite', ValueError, GCS_TFLITE_VALUE_ERR_MSG), + ('gs://@#$%/model.tflite', ValueError, GCS_TFLITE_VALUE_ERR_MSG), + ('gs://invalid space/model.tflite', ValueError, GCS_TFLITE_VALUE_ERR_MSG), + ('gs://sixty-four-characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx/model.tflite', + ValueError, GCS_TFLITE_VALUE_ERR_MSG) +] invalid_model_id_args = [ ('', ValueError, 'Model ID format is invalid.'), ('&_*#@:/?', ValueError, 'Model ID format is invalid.'), @@ -248,32 +257,34 @@ class TestModel(object): """Tests mlkit.Model class.""" def test_model_success_err_state_lro(self): - assert FULL_MODEL_ERR_STATE_LRO.model_id == MODEL_ID_1 - assert FULL_MODEL_ERR_STATE_LRO.display_name == DISPLAY_NAME_1 - assert FULL_MODEL_ERR_STATE_LRO.create_time == CREATE_TIME_DATETIME - assert FULL_MODEL_ERR_STATE_LRO.update_time == UPDATE_TIME_DATETIME - assert FULL_MODEL_ERR_STATE_LRO.validation_error == VALIDATION_ERROR_MSG - assert FULL_MODEL_ERR_STATE_LRO.published is False - assert FULL_MODEL_ERR_STATE_LRO.etag == ETAG - assert FULL_MODEL_ERR_STATE_LRO.model_hash == MODEL_HASH - assert FULL_MODEL_ERR_STATE_LRO.tags == TAGS - assert FULL_MODEL_ERR_STATE_LRO.locked is True - assert FULL_MODEL_ERR_STATE_LRO.model_format is None - assert FULL_MODEL_ERR_STATE_LRO.get_json() == FULL_MODEL_ERR_STATE_LRO_JSON + model = mlkit.Model(**FULL_MODEL_ERR_STATE_LRO_JSON) + assert model.model_id == MODEL_ID_1 + assert model.display_name == DISPLAY_NAME_1 + assert model.create_time == CREATE_TIME_DATETIME + assert model.update_time == UPDATE_TIME_DATETIME + assert model.validation_error == VALIDATION_ERROR_MSG + assert model.published is False + assert model.etag == ETAG + assert model.model_hash == MODEL_HASH + assert model.tags == TAGS + assert model.locked is True + assert model.model_format is None + assert model.get_json() == FULL_MODEL_ERR_STATE_LRO_JSON def test_model_success_published(self): - assert FULL_MODEL_PUBLISHED.model_id == MODEL_ID_1 - assert FULL_MODEL_PUBLISHED.display_name == DISPLAY_NAME_1 - assert FULL_MODEL_PUBLISHED.create_time == CREATE_TIME_DATETIME - assert FULL_MODEL_PUBLISHED.update_time == UPDATE_TIME_DATETIME - assert FULL_MODEL_PUBLISHED.validation_error is None - assert FULL_MODEL_PUBLISHED.published is True - assert FULL_MODEL_PUBLISHED.etag == ETAG - assert FULL_MODEL_PUBLISHED.model_hash == MODEL_HASH - assert FULL_MODEL_PUBLISHED.tags == TAGS - assert FULL_MODEL_PUBLISHED.locked is False - assert FULL_MODEL_PUBLISHED.model_format == TFLITE_FORMAT - assert FULL_MODEL_PUBLISHED.get_json() == FULL_MODEL_PUBLISHED_JSON + model = mlkit.Model(**FULL_MODEL_PUBLISHED_JSON) + assert model.model_id == MODEL_ID_1 + assert model.display_name == DISPLAY_NAME_1 + assert model.create_time == CREATE_TIME_DATETIME + assert model.update_time == UPDATE_TIME_DATETIME + assert model.validation_error is None + assert model.published is True + assert model.etag == ETAG + assert model.model_hash == MODEL_HASH + assert model.tags == TAGS + assert model.locked is False + assert model.model_format == TFLITE_FORMAT + assert model.get_json() == FULL_MODEL_PUBLISHED_JSON def test_model_keyword_based_creation_and_setters(self): model = mlkit.Model(display_name=DISPLAY_NAME_1, tags=TAGS, model_format=TFLITE_FORMAT) @@ -344,6 +355,12 @@ def test_model_source_validation_errors(self, model_source, error_message): mlkit.TFLiteFormat(model_source=model_source) check_error(err.value, TypeError, error_message) + @pytest.mark.parametrize('uri, exc_type, error_message', invalid_gcs_tflite_uri_args) + def test_gcs_tflite_source_validation_errors(self, uri, exc_type, error_message): + with pytest.raises(exc_type) as err: + mlkit.TFLiteGCSModelSource(gcs_tflite_uri=uri) + check_error(err.value, exc_type, error_message) + class TestGetModel(object): """Tests mlkit.get_model.""" @classmethod From 1f018fe608ededf69c3bfa1ec6fe4c12fc6b9add Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 27 Aug 2019 20:11:44 -0400 Subject: [PATCH 03/11] more review fixes --- firebase_admin/mlkit.py | 2 -- tests/test_mlkit.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 2fe827679..459a6da34 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -381,8 +381,6 @@ def _validate_model_id(model_id): def _validate_display_name(display_name): - if not isinstance(display_name, six.string_types): - raise TypeError('Display name must be a string.') if not re.match(r'^[A-Za-z0-9_-]{1,60}$', display_name): raise ValueError('Display name format is invalid.') return display_name diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 2eaf88f83..8ca6d39ef 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -175,7 +175,7 @@ invalid_display_name_args = [ ('', ValueError, 'Display name format is invalid.'), ('&_*#@:/?', ValueError, 'Display name format is invalid.'), - (12345, TypeError, 'Display name must be a string.') + (12345, TypeError, 'expected string or buffer') ] invalid_tags_args = [ ('tag1', TypeError, 'Tags must be a list of strings.'), From dfe0a37b979d696e8a28b6506dff37e8327c3676 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 27 Aug 2019 20:37:58 -0400 Subject: [PATCH 04/11] review fixes 2 --- firebase_admin/mlkit.py | 15 ++++++++------- tests/test_mlkit.py | 4 ++-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 459a6da34..a2aeb6729 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -30,6 +30,10 @@ _MLKIT_ATTRIBUTE = '_mlkit' _MAX_PAGE_SIZE = 100 +# These are coincidentally the same. They are not related. +_MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') def _get_mlkit_service(app): @@ -78,7 +82,6 @@ def __init__(self, display_name=None, tags=None, model_format=None, **kwargs): else: raise TypeError('Unsupported model format type.') - def __eq__(self, other): if isinstance(other, self.__class__): return self._data == other._data # pylint: disable=protected-access @@ -374,14 +377,12 @@ def _validate_and_parse_name(name): def _validate_model_id(model_id): - if not isinstance(model_id, six.string_types): - raise TypeError('Model ID must be a string.') - if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id): + if not _MODEL_ID_PATTERN.match(model_id): raise ValueError('Model ID format is invalid.') def _validate_display_name(display_name): - if not re.match(r'^[A-Za-z0-9_-]{1,60}$', display_name): + if not _DISPLAY_NAME_PATTERN.match(display_name): raise ValueError('Display name format is invalid.') return display_name @@ -390,7 +391,7 @@ def _validate_tags(tags): if not isinstance(tags, list) or not \ all(isinstance(tag, six.string_types) for tag in tags): raise TypeError('Tags must be a list of strings.') - if not all(re.match(r'^[A-Za-z0-9_-]{1,60}$', tag) for tag in tags): + if not all(_TAG_PATTERN.match(tag) for tag in tags): raise ValueError('Tag format is invalid.') return tags @@ -403,7 +404,7 @@ def _validate_gcs_tflite_uri(uri): return uri def _validate_model_format(model_format): - if model_format is not None: + if model_format: if not isinstance(model_format, ModelFormat): raise TypeError('Model format must be a ModelFormat object.') return model_format diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 8ca6d39ef..b687752df 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -210,8 +210,8 @@ invalid_model_id_args = [ ('', ValueError, 'Model ID format is invalid.'), ('&_*#@:/?', ValueError, 'Model ID format is invalid.'), - (None, TypeError, 'Model ID must be a string.'), - (12345, TypeError, 'Model ID must be a string.'), + (None, TypeError, 'expected string or buffer'), + (12345, TypeError, 'expected string or buffer'), ] PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ '1 and {0}'.format(mlkit._MAX_PAGE_SIZE) From 7704c44d1d93c28794c93a0673115ab7297dbcc3 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 27 Aug 2019 20:46:04 -0400 Subject: [PATCH 05/11] review fixes 3 --- firebase_admin/mlkit.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index a2aeb6729..457cb9471 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -30,10 +30,12 @@ _MLKIT_ATTRIBUTE = '_mlkit' _MAX_PAGE_SIZE = 100 -# These are coincidentally the same. They are not related. _MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') _DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') _TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') +_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+') +_RESOURCE_NAME_PATTERN = re.compile( + r'^projects/(?P[^/]+)/models/(?P[A-Za-z0-9_-]{1,60})$') def _get_mlkit_service(app): @@ -366,11 +368,7 @@ def _validate_and_parse_name(name): # The resource name is added automatically from API call responses. # The only way it could be invalid is if someone tries to # create a model from a dictionary manually and does it incorrectly. - if not isinstance(name, six.string_types): - raise TypeError('Model resource name must be a string.') - matcher = re.match( - r'^projects/(?P[^/]+)/models/(?P[A-Za-z0-9_-]{1,60})$', - name) + matcher = _RESOURCE_NAME_PATTERN.match(name) if not matcher: raise ValueError('Model resource name format is invalid.') return matcher.group('project_id'), matcher.group('model_id') @@ -399,7 +397,7 @@ def _validate_tags(tags): def _validate_gcs_tflite_uri(uri): # GCS Bucket naming rules are complex. The regex is not comprehensive. # See https://cloud.google.com/storage/docs/naming for full details. - if not re.match(r'^gs://[a-z0-9_.-]{3,63}/.+', uri): + if not _GCS_TFLITE_URI_PATTERN.match(uri): raise ValueError('GCS TFLite URI format is invalid.') return uri From 8381ac5daa90bf55e73256466fe5b75b0513f081 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 27 Aug 2019 20:55:41 -0400 Subject: [PATCH 06/11] review fixes 4 --- firebase_admin/mlkit.py | 18 +++++++++--------- tests/test_mlkit.py | 14 +++++++------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 457cb9471..b1e2810f8 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -80,7 +80,7 @@ def __init__(self, display_name=None, tags=None, model_format=None, **kwargs): if model_format is not None: _validate_model_format(model_format) if isinstance(model_format, TFLiteFormat): - self._data['tfliteModel'] = model_format.get_json() + self._data['tfliteModel'] = model_format.as_dict() else: raise TypeError('Unsupported model format type.') @@ -180,16 +180,16 @@ def model_format(self): def model_format(self, model_format): if not isinstance(model_format, TFLiteFormat): raise TypeError('Unsupported model format type.') - self._data['tfliteModel'] = model_format.get_json() + self._data['tfliteModel'] = model_format.as_dict() return self - def get_json(self): + def as_dict(self): return self._data class ModelFormat(object): """Abstract base class representing a Model Format such as TFLite.""" - def get_json(self): + def as_dict(self): raise NotImplementedError @@ -203,7 +203,7 @@ def __init__(self, model_source=None, **kwargs): raise TypeError('Model source must be a ModelSource object.') # Set based on specific sub type if isinstance(model_source, TFLiteGCSModelSource): - self._data['gcsTfliteUri'] = model_source.get_json() + self._data['gcsTfliteUri'] = model_source.as_dict() else: raise TypeError('Unsupported model source type.') @@ -226,7 +226,7 @@ def model_source(self): def model_source(self, model_source): if model_source is not None: if isinstance(model_source, TFLiteGCSModelSource): - self._data['gcsTfliteUri'] = model_source.get_json() + self._data['gcsTfliteUri'] = model_source.as_dict() else: raise TypeError('Unsupported model source type.') @@ -234,13 +234,13 @@ def model_source(self, model_source): def size_bytes(self): return self._data.get('sizeBytes') - def get_json(self): + def as_dict(self): return self._data class TFLiteModelSource(object): """Abstract base class representing a model source for TFLite format models.""" - def get_json(self): + def as_dict(self): raise NotImplementedError @@ -266,7 +266,7 @@ def gcs_tflite_uri(self): def gcs_tflite_uri(self, gcs_tflite_uri): self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) - def get_json(self): + def as_dict(self): return self._gcs_tflite_uri #TODO(ifielker): implement from_saved_model etc. diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index b687752df..6680feed9 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -269,7 +269,7 @@ def test_model_success_err_state_lro(self): assert model.tags == TAGS assert model.locked is True assert model.model_format is None - assert model.get_json() == FULL_MODEL_ERR_STATE_LRO_JSON + assert model.as_dict() == FULL_MODEL_ERR_STATE_LRO_JSON def test_model_success_published(self): model = mlkit.Model(**FULL_MODEL_PUBLISHED_JSON) @@ -284,14 +284,14 @@ def test_model_success_published(self): assert model.tags == TAGS assert model.locked is False assert model.model_format == TFLITE_FORMAT - assert model.get_json() == FULL_MODEL_PUBLISHED_JSON + assert model.as_dict() == FULL_MODEL_PUBLISHED_JSON def test_model_keyword_based_creation_and_setters(self): model = mlkit.Model(display_name=DISPLAY_NAME_1, tags=TAGS, model_format=TFLITE_FORMAT) assert model.display_name == DISPLAY_NAME_1 assert model.tags == TAGS assert model.model_format == TFLITE_FORMAT - assert model.get_json() == { + assert model.as_dict() == { 'displayName': DISPLAY_NAME_1, 'tags': TAGS, 'tfliteModel': TFLITE_FORMAT_JSON @@ -300,7 +300,7 @@ def test_model_keyword_based_creation_and_setters(self): model.display_name = DISPLAY_NAME_2 model.tags = TAGS_2 model.model_format = TFLITE_FORMAT_2 - assert model.get_json() == { + assert model.as_dict() == { 'displayName': DISPLAY_NAME_2, 'tags': TAGS_2, 'tfliteModel': TFLITE_FORMAT_JSON_2 @@ -310,7 +310,7 @@ def test_model_format_source_creation(self): model_source = mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = mlkit.TFLiteFormat(model_source=model_source) model = mlkit.Model(display_name=DISPLAY_NAME_1, model_format=model_format) - assert model.get_json() == { + assert model.as_dict() == { 'displayName': DISPLAY_NAME_1, 'tfliteModel': { 'gcsTfliteUri': GCS_TFLITE_URI @@ -321,13 +321,13 @@ def test_model_source_setters(self): model_source = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) model_source.gcs_tflite_uri = GCS_TFLITE_URI_2 assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 - assert model_source.get_json() == GCS_TFLITE_URI_2 + assert model_source.as_dict() == GCS_TFLITE_URI_2 def test_model_format_setters(self): model_format = mlkit.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 assert model_format.model_source == GCS_TFLITE_MODEL_SOURCE_2 - assert model_format.get_json() == { + assert model_format.as_dict() == { 'gcsTfliteUri': GCS_TFLITE_URI_2 } From a2e75447d587a502567a7c859fab49507b002b8c Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 27 Aug 2019 21:19:02 -0400 Subject: [PATCH 07/11] review fixes 5 --- tests/test_mlkit.py | 60 ++++++++++++++++++++++----------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 6680feed9..c4d238024 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -173,9 +173,9 @@ ERROR_RESPONSE_BAD_REQUEST = json.dumps(ERROR_JSON_BAD_REQUEST) invalid_display_name_args = [ - ('', ValueError, 'Display name format is invalid.'), - ('&_*#@:/?', ValueError, 'Display name format is invalid.'), - (12345, TypeError, 'expected string or buffer') + ('', ValueError), + ('&_*#@:/?', ValueError), + (12345, TypeError) ] invalid_tags_args = [ ('tag1', TypeError, 'Tags must be a list of strings.'), @@ -195,23 +195,22 @@ (mlkit.TFLiteModelSource(), 'Unsupported model source type.'), ] -GCS_TFLITE_VALUE_ERR_MSG = 'GCS TFLite URI format is invalid.' invalid_gcs_tflite_uri_args = [ - (123, TypeError, 'expected string or buffer'), - ('abc', ValueError, GCS_TFLITE_VALUE_ERR_MSG), - ('gs://NO_CAPITALS', ValueError, GCS_TFLITE_VALUE_ERR_MSG), - ('gs://abc/', ValueError, GCS_TFLITE_VALUE_ERR_MSG), - ('gs://aa/model.tflite', ValueError, GCS_TFLITE_VALUE_ERR_MSG), - ('gs://@#$%/model.tflite', ValueError, GCS_TFLITE_VALUE_ERR_MSG), - ('gs://invalid space/model.tflite', ValueError, GCS_TFLITE_VALUE_ERR_MSG), + (123, TypeError), + ('abc', ValueError), + ('gs://NO_CAPITALS', ValueError), + ('gs://abc/', ValueError), + ('gs://aa/model.tflite', ValueError), + ('gs://@#$%/model.tflite', ValueError), + ('gs://invalid space/model.tflite', ValueError), ('gs://sixty-four-characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx/model.tflite', - ValueError, GCS_TFLITE_VALUE_ERR_MSG) + ValueError) ] invalid_model_id_args = [ - ('', ValueError, 'Model ID format is invalid.'), - ('&_*#@:/?', ValueError, 'Model ID format is invalid.'), - (None, TypeError, 'expected string or buffer'), - (12345, TypeError, 'expected string or buffer'), + ('', ValueError), + ('&_*#@:/?', ValueError), + (None, TypeError), + (12345, TypeError), ] PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ '1 and {0}'.format(mlkit._MAX_PAGE_SIZE) @@ -228,9 +227,10 @@ invalid_string_or_none_args = [0, -1, 4.2, 0x10, False, list(), dict()] -def check_error(err, err_type, msg): +def check_error(err, err_type, msg=None): assert isinstance(err, err_type) - assert str(err) == msg + if msg: + assert str(err) == msg def check_firebase_error(err, code, status, msg): @@ -331,11 +331,11 @@ def test_model_format_setters(self): 'gcsTfliteUri': GCS_TFLITE_URI_2 } - @pytest.mark.parametrize('display_name, exc_type, error_message', invalid_display_name_args) - def test_model_display_name_validation_errors(self, display_name, exc_type, error_message): + @pytest.mark.parametrize('display_name, exc_type', invalid_display_name_args) + def test_model_display_name_validation_errors(self, display_name, exc_type): with pytest.raises(exc_type) as err: mlkit.Model(display_name=display_name) - check_error(err.value, exc_type, error_message) + check_error(err.value, exc_type) @pytest.mark.parametrize('tags, exc_type, error_message', invalid_tags_args) def test_model_tags_validation_errors(self, tags, exc_type, error_message): @@ -355,11 +355,11 @@ def test_model_source_validation_errors(self, model_source, error_message): mlkit.TFLiteFormat(model_source=model_source) check_error(err.value, TypeError, error_message) - @pytest.mark.parametrize('uri, exc_type, error_message', invalid_gcs_tflite_uri_args) - def test_gcs_tflite_source_validation_errors(self, uri, exc_type, error_message): + @pytest.mark.parametrize('uri, exc_type', invalid_gcs_tflite_uri_args) + def test_gcs_tflite_source_validation_errors(self, uri, exc_type): with pytest.raises(exc_type) as err: mlkit.TFLiteGCSModelSource(gcs_tflite_uri=uri) - check_error(err.value, exc_type, error_message) + check_error(err.value, exc_type) class TestGetModel(object): """Tests mlkit.get_model.""" @@ -386,11 +386,11 @@ def test_get_model(self): assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 - @pytest.mark.parametrize('model_id, exc_type, error_message', invalid_model_id_args) - def test_get_model_validation_errors(self, model_id, exc_type, error_message): + @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args) + def test_get_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as err: mlkit.get_model(model_id) - check_error(err.value, exc_type, error_message) + check_error(err.value, exc_type) def test_get_model_error(self): recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) @@ -436,11 +436,11 @@ def test_delete_model(self): assert recorder[0].method == 'DELETE' assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) - @pytest.mark.parametrize('model_id, exc_type, error_message', invalid_model_id_args) - def test_delete_model_validation_errors(self, model_id, exc_type, error_message): + @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args) + def test_delete_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as err: mlkit.delete_model(model_id) - check_error(err.value, exc_type, error_message) + check_error(err.value, exc_type) def test_delete_model_error(self): recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) From cadd6c6d22ea0b6cc24914768df4d65cead9ea53 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 27 Aug 2019 21:27:30 -0400 Subject: [PATCH 08/11] fixed lint --- firebase_admin/mlkit.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index b1e2810f8..51f5db48f 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -111,6 +111,7 @@ def display_name(self, display_name): @property def create_time(self): + """Returns the creation timestamp""" create_time = self._data.get('createTime') if not create_time: return None @@ -125,6 +126,7 @@ def create_time(self): @property def update_time(self): + """Returns the last update timestamp""" update_time = self._data.get('updateTime') if not update_time: return None From b02ea22b16d1e317f92f23a7a5aa3347e86fa35e Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 28 Aug 2019 20:22:55 -0400 Subject: [PATCH 09/11] review comments --- firebase_admin/mlkit.py | 100 +++++++++++++++++---------------- tests/test_mlkit.py | 119 ++++++++++++++++++++-------------------- 2 files changed, 114 insertions(+), 105 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 51f5db48f..cedb7f662 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -70,23 +70,31 @@ def delete_model(model_id, app=None): class Model(object): - """A Firebase ML Kit Model object.""" + """A Firebase ML Kit Model object. + + Args: + display_name: String - The display name of your model - used to identify your model in code. + tags: Optional list of strings associated with your model. Can be used in list queries. + model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details. + kwargs: A set of keywords returned by an API response. + """ def __init__(self, display_name=None, tags=None, model_format=None, **kwargs): self._data = kwargs + self._model_format = None + tflite_format = self._data.pop('tfliteModel', None) + if tflite_format: + self._model_format = TFLiteFormat(**tflite_format) if display_name is not None: - self._data['displayName'] = _validate_display_name(display_name) + self.display_name = display_name if tags is not None: - self._data['tags'] = _validate_tags(tags) + self.tags = tags if model_format is not None: - _validate_model_format(model_format) - if isinstance(model_format, TFLiteFormat): - self._data['tfliteModel'] = model_format.as_dict() - else: - raise TypeError('Unsupported model format type.') + self.model_format = model_format def __eq__(self, other): if isinstance(other, self.__class__): - return self._data == other._data # pylint: disable=protected-access + # pylint: disable=protected-access + return self._data == other._data and self._model_format == other._model_format else: return False @@ -117,8 +125,6 @@ def create_time(self): return None seconds = create_time.get('seconds') - if not seconds: - return None if not isinstance(seconds, numbers.Number): return None @@ -132,8 +138,6 @@ def update_time(self): return None seconds = update_time.get('seconds') - if not seconds: - return None if not isinstance(seconds, numbers.Number): return None @@ -141,14 +145,11 @@ def update_time(self): @property def validation_error(self): - return self._data.get('state') and \ - self._data.get('state').get('validationError') and \ - self._data.get('state').get('validationError').get('message') + return self._data.get('state', {}).get('validationError', {}).get('message') @property def published(self): - return bool(self._data.get('state') and - self._data.get('state').get('published')) + return bool(self._data.get('state', {}).get('published')) @property def etag(self): @@ -174,19 +175,20 @@ def locked(self): @property def model_format(self): - if self._data.get('tfliteModel'): - return TFLiteFormat(**self._data.get('tfliteModel')) - return None + return self._model_format @model_format.setter def model_format(self, model_format): - if not isinstance(model_format, TFLiteFormat): - raise TypeError('Unsupported model format type.') - self._data['tfliteModel'] = model_format.as_dict() + if model_format is not None: + _validate_model_format(model_format) + self._model_format = model_format #Can be None return self def as_dict(self): - return self._data + copy = dict(self._data) + if self._model_format: + copy.update(self._model_format.as_dict()) + return copy class ModelFormat(object): @@ -196,22 +198,27 @@ def as_dict(self): class TFLiteFormat(ModelFormat): - """Model format representing a TFLite model.""" + """Model format representing a TFLite model. + + Args: + model_source: A TFLiteModelSource sub class. Specifies the details of the model source. + kwargs: A set of keywords returned by an API response + """ def __init__(self, model_source=None, **kwargs): self._data = kwargs + self._model_source = None + + gcs_tflite_uri = self._data.pop('gcsTfliteUri', None) + if gcs_tflite_uri: + self._model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) + if model_source is not None: - # Check for correct base type - if not isinstance(model_source, TFLiteModelSource): - raise TypeError('Model source must be a ModelSource object.') - # Set based on specific sub type - if isinstance(model_source, TFLiteGCSModelSource): - self._data['gcsTfliteUri'] = model_source.as_dict() - else: - raise TypeError('Unsupported model source type.') + self.model_source = model_source def __eq__(self, other): if isinstance(other, self.__class__): - return self._data == other._data # pylint: disable=protected-access + # pylint: disable=protected-access + return self._data == other._data and self._model_source == other._model_source else: return False @@ -220,24 +227,24 @@ def __ne__(self, other): @property def model_source(self): - if self._data.get('gcsTfliteUri'): - return TFLiteGCSModelSource(self._data.get('gcsTfliteUri')) - return None + return self._model_source @model_source.setter def model_source(self, model_source): if model_source is not None: - if isinstance(model_source, TFLiteGCSModelSource): - self._data['gcsTfliteUri'] = model_source.as_dict() - else: - raise TypeError('Unsupported model source type.') + if not isinstance(model_source, TFLiteModelSource): + raise TypeError('Model source must be a TFLiteModelSource object.') + self._model_source = model_source # Can be None @property def size_bytes(self): return self._data.get('sizeBytes') def as_dict(self): - return self._data + copy = dict(self._data) + if self._model_source: + copy.update(self._model_source.as_dict()) + return {'tfliteModel': copy} class TFLiteModelSource(object): @@ -269,7 +276,7 @@ def gcs_tflite_uri(self, gcs_tflite_uri): self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) def as_dict(self): - return self._gcs_tflite_uri + return {"gcsTfliteUri": self._gcs_tflite_uri} #TODO(ifielker): implement from_saved_model etc. @@ -404,9 +411,8 @@ def _validate_gcs_tflite_uri(uri): return uri def _validate_model_format(model_format): - if model_format: - if not isinstance(model_format, ModelFormat): - raise TypeError('Model format must be a ModelFormat object.') + if not isinstance(model_format, ModelFormat): + raise TypeError('Model format must be a ModelFormat object.') return model_format def _validate_list_filter(list_filter): diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index c4d238024..3771e7f62 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -96,6 +96,7 @@ } GCS_TFLITE_URI = 'gs://my_bucket/mymodel.tflite' +GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI} GCS_TFLITE_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) TFLITE_FORMAT_JSON = { 'gcsTfliteUri': GCS_TFLITE_URI, @@ -104,6 +105,7 @@ TFLITE_FORMAT = mlkit.TFLiteFormat(**TFLITE_FORMAT_JSON) GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite' +GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2} GCS_TFLITE_MODEL_SOURCE_2 = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI_2) TFLITE_FORMAT_JSON_2 = { 'gcsTfliteUri': GCS_TFLITE_URI_2, @@ -172,40 +174,6 @@ } ERROR_RESPONSE_BAD_REQUEST = json.dumps(ERROR_JSON_BAD_REQUEST) -invalid_display_name_args = [ - ('', ValueError), - ('&_*#@:/?', ValueError), - (12345, TypeError) -] -invalid_tags_args = [ - ('tag1', TypeError, 'Tags must be a list of strings.'), - (123, TypeError, 'Tags must be a list of strings.'), - (['tag1', 123, 'tag2'], TypeError, 'Tags must be a list of strings.'), - (['tag1', '@#$%^&'], ValueError, 'Tag format is invalid.'), - (['', 'tag2'], ValueError, 'Tag format is invalid.'), - (['sixty-one_characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', - 'tag2'], ValueError, 'Tag format is invalid.') -] -invalid_model_format_args = [ - (123, 'Model format must be a ModelFormat object.'), - (mlkit.ModelFormat(), 'Unsupported model format type.') -] -invalid_model_source_args = [ - (123, 'Model source must be a ModelSource object.'), - (mlkit.TFLiteModelSource(), 'Unsupported model source type.'), - -] -invalid_gcs_tflite_uri_args = [ - (123, TypeError), - ('abc', ValueError), - ('gs://NO_CAPITALS', ValueError), - ('gs://abc/', ValueError), - ('gs://aa/model.tflite', ValueError), - ('gs://@#$%/model.tflite', ValueError), - ('gs://invalid space/model.tflite', ValueError), - ('gs://sixty-four-characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx/model.tflite', - ValueError) -] invalid_model_id_args = [ ('', ValueError), ('&_*#@:/?', ValueError), @@ -214,16 +182,6 @@ ] PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ '1 and {0}'.format(mlkit._MAX_PAGE_SIZE) -invalid_page_size_args = [ - ('abc', TypeError, 'Page size must be a number or None.'), - (4.2, TypeError, 'Page size must be a number or None.'), - (list(), TypeError, 'Page size must be a number or None.'), - (dict(), TypeError, 'Page size must be a number or None.'), - (True, TypeError, 'Page size must be a number or None.'), - (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), - (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), - (mlkit._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) -] invalid_string_or_none_args = [0, -1, 4.2, 0x10, False, list(), dict()] @@ -321,41 +279,77 @@ def test_model_source_setters(self): model_source = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) model_source.gcs_tflite_uri = GCS_TFLITE_URI_2 assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 - assert model_source.as_dict() == GCS_TFLITE_URI_2 + assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2 def test_model_format_setters(self): model_format = mlkit.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 assert model_format.model_source == GCS_TFLITE_MODEL_SOURCE_2 assert model_format.as_dict() == { - 'gcsTfliteUri': GCS_TFLITE_URI_2 + 'tfliteModel': { + 'gcsTfliteUri': GCS_TFLITE_URI_2 + } } - @pytest.mark.parametrize('display_name, exc_type', invalid_display_name_args) + @pytest.mark.parametrize('display_name, exc_type', [ + ('', ValueError), + ('&_*#@:/?', ValueError), + (12345, TypeError) + ]) def test_model_display_name_validation_errors(self, display_name, exc_type): with pytest.raises(exc_type) as err: mlkit.Model(display_name=display_name) check_error(err.value, exc_type) - @pytest.mark.parametrize('tags, exc_type, error_message', invalid_tags_args) + @pytest.mark.parametrize('tags, exc_type, error_message', [ + ('tag1', TypeError, 'Tags must be a list of strings.'), + (123, TypeError, 'Tags must be a list of strings.'), + (['tag1', 123, 'tag2'], TypeError, 'Tags must be a list of strings.'), + (['tag1', '@#$%^&'], ValueError, 'Tag format is invalid.'), + (['', 'tag2'], ValueError, 'Tag format is invalid.'), + (['sixty-one_characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', + 'tag2'], ValueError, 'Tag format is invalid.') + ]) def test_model_tags_validation_errors(self, tags, exc_type, error_message): with pytest.raises(exc_type) as err: mlkit.Model(tags=tags) check_error(err.value, exc_type, error_message) - @pytest.mark.parametrize('model_format, error_message', invalid_model_format_args) - def test_model_format_validation_errors(self, model_format, error_message): + @pytest.mark.parametrize('model_format', [ + 123, + "abc", + {}, + [], + True + ]) + def test_model_format_validation_errors(self, model_format): with pytest.raises(TypeError) as err: mlkit.Model(model_format=model_format) - check_error(err.value, TypeError, error_message) - - @pytest.mark.parametrize('model_source, error_message', invalid_model_source_args) - def test_model_source_validation_errors(self, model_source, error_message): + check_error(err.value, TypeError, 'Model format must be a ModelFormat object.') + + @pytest.mark.parametrize('model_source', [ + 123, + "abc", + {}, + [], + True + ]) + def test_model_source_validation_errors(self, model_source): with pytest.raises(TypeError) as err: mlkit.TFLiteFormat(model_source=model_source) - check_error(err.value, TypeError, error_message) - - @pytest.mark.parametrize('uri, exc_type', invalid_gcs_tflite_uri_args) + check_error(err.value, TypeError, 'Model source must be a TFLiteModelSource object.') + + @pytest.mark.parametrize('uri, exc_type', [ + (123, TypeError), + ('abc', ValueError), + ('gs://NO_CAPITALS', ValueError), + ('gs://abc/', ValueError), + ('gs://aa/model.tflite', ValueError), + ('gs://@#$%/model.tflite', ValueError), + ('gs://invalid space/model.tflite', ValueError), + ('gs://sixty-four-characters_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx/model.tflite', + ValueError) + ]) def test_gcs_tflite_source_validation_errors(self, uri, exc_type): with pytest.raises(exc_type) as err: mlkit.TFLiteGCSModelSource(gcs_tflite_uri=uri) @@ -523,7 +517,16 @@ def test_list_models_list_filter_validation(self, list_filter): mlkit.list_models(list_filter=list_filter) check_error(err.value, TypeError, 'List filter must be a string or None.') - @pytest.mark.parametrize('page_size, exc_type, error_message', invalid_page_size_args) + @pytest.mark.parametrize('page_size, exc_type, error_message', [ + ('abc', TypeError, 'Page size must be a number or None.'), + (4.2, TypeError, 'Page size must be a number or None.'), + (list(), TypeError, 'Page size must be a number or None.'), + (dict(), TypeError, 'Page size must be a number or None.'), + (True, TypeError, 'Page size must be a number or None.'), + (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), + (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), + (mlkit._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) + ]) def test_list_models_page_size_validation(self, page_size, exc_type, error_message): with pytest.raises(exc_type) as err: mlkit.list_models(page_size=page_size) From a0a2411a66d1c22f9ea713658397135316dafb34 Mon Sep 17 00:00:00 2001 From: ifielker Date: Thu, 29 Aug 2019 15:46:33 -0400 Subject: [PATCH 10/11] more review changes --- firebase_admin/mlkit.py | 58 +++++++++++++++++++++++------------------ tests/test_mlkit.py | 15 ++++++----- 2 files changed, 41 insertions(+), 32 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index cedb7f662..da2e42a73 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -55,7 +55,7 @@ def _get_mlkit_service(app): def get_model(model_id, app=None): mlkit_service = _get_mlkit_service(app) - return Model(**mlkit_service.get_model(model_id)) + return Model.from_dict(mlkit_service.get_model(model_id)) def list_models(list_filter=None, page_size=None, page_token=None, app=None): @@ -73,17 +73,14 @@ class Model(object): """A Firebase ML Kit Model object. Args: - display_name: String - The display name of your model - used to identify your model in code. + display_name: The display name of your model - used to identify your model in code. tags: Optional list of strings associated with your model. Can be used in list queries. model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details. - kwargs: A set of keywords returned by an API response. """ - def __init__(self, display_name=None, tags=None, model_format=None, **kwargs): - self._data = kwargs + def __init__(self, display_name=None, tags=None, model_format=None): + self._data = {} self._model_format = None - tflite_format = self._data.pop('tfliteModel', None) - if tflite_format: - self._model_format = TFLiteFormat(**tflite_format) + if display_name is not None: self.display_name = display_name if tags is not None: @@ -91,6 +88,17 @@ def __init__(self, display_name=None, tags=None, model_format=None, **kwargs): if model_format is not None: self.model_format = model_format + @classmethod + def from_dict(cls, data): + data_copy = dict(data) + tflite_format = None + tflite_format_data = data_copy.pop('tfliteModel', None) + if tflite_format_data: + tflite_format = TFLiteFormat.from_dict(tflite_format_data) + model = Model(model_format=tflite_format) + model._data = data_copy # pylint: disable=protected-access + return model + def __eq__(self, other): if isinstance(other, self.__class__): # pylint: disable=protected-access @@ -120,11 +128,7 @@ def display_name(self, display_name): @property def create_time(self): """Returns the creation timestamp""" - create_time = self._data.get('createTime') - if not create_time: - return None - - seconds = create_time.get('seconds') + seconds = self._data.get('createTime', {}).get('seconds') if not isinstance(seconds, numbers.Number): return None @@ -133,11 +137,7 @@ def create_time(self): @property def update_time(self): """Returns the last update timestamp""" - update_time = self._data.get('updateTime') - if not update_time: - return None - - seconds = update_time.get('seconds') + seconds = self._data.get('updateTime', {}).get('seconds') if not isinstance(seconds, numbers.Number): return None @@ -204,17 +204,25 @@ class TFLiteFormat(ModelFormat): model_source: A TFLiteModelSource sub class. Specifies the details of the model source. kwargs: A set of keywords returned by an API response """ - def __init__(self, model_source=None, **kwargs): - self._data = kwargs + def __init__(self, model_source=None): + self._data = {} self._model_source = None - gcs_tflite_uri = self._data.pop('gcsTfliteUri', None) - if gcs_tflite_uri: - self._model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) - if model_source is not None: self.model_source = model_source + @classmethod + def from_dict(cls, data): + data_copy = dict(data) + model_source = None + gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None) + if gcs_tflite_uri: + model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) + tflite_format = TFLiteFormat(model_source=model_source) + tflite_format._data = data_copy # pylint: disable=protected-access + return tflite_format + + def __eq__(self, other): if isinstance(other, self.__class__): # pylint: disable=protected-access @@ -299,7 +307,7 @@ def __init__(self, list_models_func, list_filter, page_size, page_token): @property def models(self): """A list of Models from this page.""" - return [Model(**model) for model in self._list_response.get('models', [])] + return [Model.from_dict(model) for model in self._list_response.get('models', [])] @property def list_filter(self): diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 3771e7f62..c20982a2b 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -54,7 +54,7 @@ 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1 } -MODEL_1 = mlkit.Model(**MODEL_JSON_1) +MODEL_1 = mlkit.Model.from_dict(MODEL_JSON_1) MODEL_ID_2 = 'modelId2' MODEL_NAME_2 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_2) @@ -63,7 +63,7 @@ 'name': MODEL_NAME_2, 'displayName': DISPLAY_NAME_2 } -MODEL_2 = mlkit.Model(**MODEL_JSON_2) +MODEL_2 = mlkit.Model.from_dict(MODEL_JSON_2) MODEL_ID_3 = 'modelId3' MODEL_NAME_3 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_3) @@ -72,7 +72,7 @@ 'name': MODEL_NAME_3, 'displayName': DISPLAY_NAME_3 } -MODEL_3 = mlkit.Model(**MODEL_JSON_3) +MODEL_3 = mlkit.Model.from_dict(MODEL_JSON_3) MODEL_STATE_PUBLISHED_JSON = { 'published': True @@ -102,7 +102,7 @@ 'gcsTfliteUri': GCS_TFLITE_URI, 'sizeBytes': '1234567' } -TFLITE_FORMAT = mlkit.TFLiteFormat(**TFLITE_FORMAT_JSON) +TFLITE_FORMAT = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON) GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite' GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2} @@ -111,7 +111,7 @@ 'gcsTfliteUri': GCS_TFLITE_URI_2, 'sizeBytes': '2345678' } -TFLITE_FORMAT_2 = mlkit.TFLiteFormat(**TFLITE_FORMAT_JSON_2) +TFLITE_FORMAT_2 = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) FULL_MODEL_ERR_STATE_LRO_JSON = { 'name': MODEL_NAME_1, @@ -215,7 +215,7 @@ class TestModel(object): """Tests mlkit.Model class.""" def test_model_success_err_state_lro(self): - model = mlkit.Model(**FULL_MODEL_ERR_STATE_LRO_JSON) + model = mlkit.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 assert model.create_time == CREATE_TIME_DATETIME @@ -230,7 +230,7 @@ def test_model_success_err_state_lro(self): assert model.as_dict() == FULL_MODEL_ERR_STATE_LRO_JSON def test_model_success_published(self): - model = mlkit.Model(**FULL_MODEL_PUBLISHED_JSON) + model = mlkit.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 assert model.create_time == CREATE_TIME_DATETIME @@ -355,6 +355,7 @@ def test_gcs_tflite_source_validation_errors(self, uri, exc_type): mlkit.TFLiteGCSModelSource(gcs_tflite_uri=uri) check_error(err.value, exc_type) + class TestGetModel(object): """Tests mlkit.get_model.""" @classmethod From fc63db856c521508b471ad983bd0952c614f7d5e Mon Sep 17 00:00:00 2001 From: ifielker Date: Thu, 29 Aug 2019 16:09:41 -0400 Subject: [PATCH 11/11] fixed lint --- firebase_admin/mlkit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index da2e42a73..3f1a825f6 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -202,7 +202,6 @@ class TFLiteFormat(ModelFormat): Args: model_source: A TFLiteModelSource sub class. Specifies the details of the model source. - kwargs: A set of keywords returned by an API response """ def __init__(self, model_source=None): self._data = {}