diff --git a/firebase_admin/mlkit.py b/firebase_admin/ml.py similarity index 92% rename from firebase_admin/mlkit.py rename to firebase_admin/ml.py index bb277abf9..809ba9a41 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/ml.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Firebase ML Kit module. +"""Firebase ML module. This module contains functions for creating, updating, getting, listing, -deleting, publishing and unpublishing Firebase ML Kit models. +deleting, publishing and unpublishing Firebase ML models. """ @@ -46,7 +46,7 @@ except ImportError: _TF_ENABLED = False -_MLKIT_ATTRIBUTE = '_mlkit' +_ML_ATTRIBUTE = '_ml' _MAX_PAGE_SIZE = 100 _MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') _DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') @@ -60,51 +60,51 @@ r'/operation/[^/]+$') -def _get_mlkit_service(app): - """ Returns an _MLKitService instance for an App. +def _get_ml_service(app): + """ Returns an _MLService instance for an App. Args: app: A Firebase App instance (or None to use the default App). Returns: - _MLKitService: An _MLKitService for the specified App instance. + _MLService: An _MLService for the specified App instance. Raises: ValueError: If the app argument is invalid. """ - return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService) + return _utils.get_app_service(app, _ML_ATTRIBUTE, _MLService) def create_model(model, app=None): - """Creates a model in Firebase ML Kit. + """Creates a model in Firebase ML. Args: - model: An mlkit.Model to create. + model: An ml.Model to create. app: A Firebase app instance (or None to use the default app). Returns: - Model: The model that was created in Firebase ML Kit. + Model: The model that was created in Firebase ML. """ - mlkit_service = _get_mlkit_service(app) - return Model.from_dict(mlkit_service.create_model(model), app=app) + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.create_model(model), app=app) def update_model(model, app=None): - """Updates a model in Firebase ML Kit. + """Updates a model in Firebase ML. Args: - model: The mlkit.Model to update. + model: The ml.Model to update. app: A Firebase app instance (or None to use the default app). Returns: Model: The updated model. """ - mlkit_service = _get_mlkit_service(app) - return Model.from_dict(mlkit_service.update_model(model), app=app) + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.update_model(model), app=app) def publish_model(model_id, app=None): - """Publishes a model in Firebase ML Kit. + """Publishes a model in Firebase ML. Args: model_id: The id of the model to publish. @@ -113,12 +113,12 @@ def publish_model(model_id, app=None): Returns: Model: The published model. """ - mlkit_service = _get_mlkit_service(app) - return Model.from_dict(mlkit_service.set_published(model_id, publish=True), app=app) + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.set_published(model_id, publish=True), app=app) def unpublish_model(model_id, app=None): - """Unpublishes a model in Firebase ML Kit. + """Unpublishes a model in Firebase ML. Args: model_id: The id of the model to unpublish. @@ -127,12 +127,12 @@ def unpublish_model(model_id, app=None): Returns: Model: The unpublished model. """ - mlkit_service = _get_mlkit_service(app) - return Model.from_dict(mlkit_service.set_published(model_id, publish=False), app=app) + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.set_published(model_id, publish=False), app=app) def get_model(model_id, app=None): - """Gets a model from Firebase ML Kit. + """Gets a model from Firebase ML. Args: model_id: The id of the model to get. @@ -141,12 +141,12 @@ def get_model(model_id, app=None): Returns: Model: The requested model. """ - mlkit_service = _get_mlkit_service(app) - return Model.from_dict(mlkit_service.get_model(model_id), app=app) + ml_service = _get_ml_service(app) + return Model.from_dict(ml_service.get_model(model_id), app=app) def list_models(list_filter=None, page_size=None, page_token=None, app=None): - """Lists models from Firebase ML Kit. + """Lists models from Firebase ML. Args: list_filter: a list filter string such as "tags:'tag_1'". None will return all models. @@ -159,24 +159,24 @@ def list_models(list_filter=None, page_size=None, page_token=None, app=None): Returns: ListModelsPage: A (filtered) list of models. """ - mlkit_service = _get_mlkit_service(app) + ml_service = _get_ml_service(app) return ListModelsPage( - mlkit_service.list_models, list_filter, page_size, page_token, app=app) + ml_service.list_models, list_filter, page_size, page_token, app=app) def delete_model(model_id, app=None): - """Deletes a model from Firebase ML Kit. + """Deletes a model from Firebase ML. Args: model_id: The id of the model you wish to delete. app: A Firebase app instance (or None to use the default app). """ - mlkit_service = _get_mlkit_service(app) - mlkit_service.delete_model(model_id) + ml_service = _get_ml_service(app) + ml_service.delete_model(model_id) class Model(object): - """A Firebase ML Kit Model object. + """A Firebase ML Model object. Args: display_name: The display name of your model - used to identify your model in code. @@ -310,10 +310,10 @@ def wait_for_unlocked(self, max_time_seconds=None): """ if not self.locked: return - mlkit_service = _get_mlkit_service(self._app) + ml_service = _get_ml_service(self._app) op_name = self._data.get('activeOperations')[0].get('name') - model_dict = mlkit_service.handle_operation( - mlkit_service.get_operation(op_name), + model_dict = ml_service.handle_operation( + ml_service.get_operation(op_name), wait_for_operation=True, max_time_seconds=max_time_seconds) self._update_from_dict(model_dict) @@ -418,7 +418,7 @@ class _CloudStorageClient(object): """Cloud Storage helper class""" GCS_URI = 'gs://{0}/{1}' - BLOB_NAME = 'Firebase/MLKit/Models/{0}' + BLOB_NAME = 'Firebase/ML/Models/{0}' @staticmethod def _assert_gcs_enabled(): @@ -541,9 +541,9 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): """ TFLiteGCSModelSource._assert_tf_enabled() tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir) - open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) + open('firebase_ml_model.tflite', 'wb').write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file( - 'firebase_mlkit_model.tflite', bucket_name, app) + 'firebase_ml_model.tflite', bucket_name, app) @classmethod def from_keras_model(cls, keras_model, bucket_name=None, app=None): @@ -563,9 +563,9 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): """ TFLiteGCSModelSource._assert_tf_enabled() tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model) - open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) + open('firebase_ml_model.tflite', 'wb').write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file( - 'firebase_mlkit_model.tflite', bucket_name, app) + 'firebase_ml_model.tflite', bucket_name, app) @property def gcs_tflite_uri(self): @@ -577,7 +577,7 @@ def gcs_tflite_uri(self, gcs_tflite_uri): self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) def _get_signed_gcs_tflite_uri(self): - """Signs the GCS uri, so the model file can be uploaded to Firebase ML Kit and verified.""" + """Signs the GCS uri, so the model file can be uploaded to Firebase ML and verified.""" return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app) def as_dict(self, for_upload=False): @@ -697,7 +697,7 @@ def _validate_and_parse_name(name): def _validate_model(model, update_mask=None): if not isinstance(model, Model): - raise TypeError('Model must be an mlkit.Model.') + raise TypeError('Model must be an ml.Model.') if update_mask is None and not model.display_name: raise ValueError('Model must have a display name.') @@ -765,8 +765,8 @@ def _validate_page_token(page_token): raise TypeError('Page token must be a string or None.') -class _MLKitService(object): - """Firebase MLKit service.""" +class _MLService(object): + """Firebase ML service.""" PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/' OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/' @@ -777,15 +777,15 @@ def __init__(self, app): self._project_id = app.project_id if not self._project_id: raise ValueError( - 'Project ID is required to access MLKit service. Either set the ' + 'Project ID is required to access ML service. Either set the ' 'projectId option, or use service account credentials.') - self._project_url = _MLKitService.PROJECT_URL.format(self._project_id) + self._project_url = _MLService.PROJECT_URL.format(self._project_id) self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), base_url=self._project_url) self._operation_client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), - base_url=_MLKitService.OPERATION_URL) + base_url=_MLService.OPERATION_URL) def get_operation(self, op_name): _validate_and_parse_operation_name(op_name) @@ -796,8 +796,8 @@ def get_operation(self, op_name): def _exponential_backoff(self, current_attempt, stop_time): """Sleeps for the appropriate amount of time. Or throws deadline exceeded.""" - delay_factor = pow(_MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) - wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS + delay_factor = pow(_MLService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) + wait_time_seconds = delay_factor * _MLService.POLL_BASE_WAIT_TIME_SECONDS if stop_time is not None: max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds() @@ -897,7 +897,7 @@ def get_model(self, model_id): raise _utils.handle_platform_error_from_requests(error) def list_models(self, list_filter, page_size, page_token): - """ lists Firebase ML Kit models.""" + """ lists Firebase ML models.""" _validate_list_filter(list_filter) _validate_page_size(page_size) _validate_page_token(page_token) diff --git a/tests/test_mlkit.py b/tests/test_ml.py similarity index 79% rename from tests/test_mlkit.py rename to tests/test_ml.py index dbe590673..e66507e88 100644 --- a/tests/test_mlkit.py +++ b/tests/test_ml.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test cases for the firebase_admin.mlkit module.""" +"""Test cases for the firebase_admin.ml module.""" import datetime import json @@ -20,7 +20,7 @@ import firebase_admin from firebase_admin import exceptions -from firebase_admin import mlkit +from firebase_admin import ml from tests import testutils BASE_URL = 'https://mlkit.googleapis.com/v1beta1/' @@ -61,7 +61,7 @@ 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1 } -MODEL_1 = mlkit.Model.from_dict(MODEL_JSON_1) +MODEL_1 = ml.Model.from_dict(MODEL_JSON_1) MODEL_ID_2 = 'modelId2' MODEL_NAME_2 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_2) @@ -70,7 +70,7 @@ 'name': MODEL_NAME_2, 'displayName': DISPLAY_NAME_2 } -MODEL_2 = mlkit.Model.from_dict(MODEL_JSON_2) +MODEL_2 = ml.Model.from_dict(MODEL_JSON_2) MODEL_ID_3 = 'modelId3' MODEL_NAME_3 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_3) @@ -79,7 +79,7 @@ 'name': MODEL_NAME_3, 'displayName': DISPLAY_NAME_3 } -MODEL_3 = mlkit.Model.from_dict(MODEL_JSON_3) +MODEL_3 = ml.Model.from_dict(MODEL_JSON_3) MODEL_STATE_PUBLISHED_JSON = { 'published': True @@ -107,12 +107,12 @@ GCS_BLOB_NAME = 'mymodel.tflite' GCS_TFLITE_URI = 'gs://{0}/{1}'.format(GCS_BUCKET_NAME, GCS_BLOB_NAME) GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI} -GCS_TFLITE_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) +GCS_TFLITE_MODEL_SOURCE = ml.TFLiteGCSModelSource(GCS_TFLITE_URI) TFLITE_FORMAT_JSON = { 'gcsTfliteUri': GCS_TFLITE_URI, 'sizeBytes': '1234567' } -TFLITE_FORMAT = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON) +TFLITE_FORMAT = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON) GCS_TFLITE_SIGNED_URI_PATTERN = ( 'https://storage.googleapis.com/{0}/{1}?X-Goog-Algorithm=GOOG4-RSA-SHA256&foo') @@ -120,12 +120,12 @@ GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite' GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2} -GCS_TFLITE_MODEL_SOURCE_2 = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI_2) +GCS_TFLITE_MODEL_SOURCE_2 = ml.TFLiteGCSModelSource(GCS_TFLITE_URI_2) TFLITE_FORMAT_JSON_2 = { 'gcsTfliteUri': GCS_TFLITE_URI_2, 'sizeBytes': '2345678' } -TFLITE_FORMAT_2 = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) +TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) CREATED_UPDATED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, @@ -137,7 +137,7 @@ 'modelHash': MODEL_HASH, 'tags': TAGS, } -CREATED_UPDATED_MODEL_1 = mlkit.Model.from_dict(CREATED_UPDATED_MODEL_JSON_1) +CREATED_UPDATED_MODEL_1 = ml.Model.from_dict(CREATED_UPDATED_MODEL_JSON_1) LOCKED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, @@ -202,7 +202,7 @@ 'tags': TAGS, 'tfliteModel': TFLITE_FORMAT_JSON } -FULL_MODEL_PUBLISHED = mlkit.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) +FULL_MODEL_PUBLISHED = ml.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) OPERATION_DONE_FULL_MODEL_PUBLISHED_JSON = { 'name': OPERATION_NAME_1, 'done': True, @@ -279,7 +279,7 @@ 'operations/project/1234/model/abc/operation/123/extrathing', ] PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ - '1 and {0}'.format(mlkit._MAX_PAGE_SIZE) + '1 and {0}'.format(ml._MAX_PAGE_SIZE) INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, list(), dict()] @@ -309,10 +309,10 @@ def check_firebase_error(excinfo, code, status, msg): assert str(err) == msg -def instrument_mlkit_service(status=200, payload=None, operations=False, app=None): +def instrument_ml_service(status=200, payload=None, operations=False, app=None): if not app: app = firebase_admin.get_app() - mlkit_service = mlkit._get_mlkit_service(app) + ml_service = ml._get_ml_service(app) recorder = [] session_url = 'https://mlkit.googleapis.com/v1beta1/' @@ -322,10 +322,10 @@ def instrument_mlkit_service(status=200, payload=None, operations=False, app=Non adapter = testutils.MockAdapter if operations: - mlkit_service._operation_client.session.mount( + ml_service._operation_client.session.mount( session_url, adapter(payload, status, recorder)) else: - mlkit_service._client.session.mount( + ml_service._client.session.mount( session_url, adapter(payload, status, recorder)) return recorder @@ -333,23 +333,23 @@ class _TestStorageClient(object): @staticmethod def upload(bucket_name, model_file_name, app): del app # unused variable - blob_name = mlkit._CloudStorageClient.BLOB_NAME.format(model_file_name) - return mlkit._CloudStorageClient.GCS_URI.format(bucket_name, blob_name) + blob_name = ml._CloudStorageClient.BLOB_NAME.format(model_file_name) + return ml._CloudStorageClient.GCS_URI.format(bucket_name, blob_name) @staticmethod def sign_uri(gcs_tflite_uri, app): del app # unused variable - bucket_name, blob_name = mlkit._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) + bucket_name, blob_name = ml._CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) return GCS_TFLITE_SIGNED_URI_PATTERN.format(bucket_name, blob_name) class TestModel(object): - """Tests mlkit.Model class.""" + """Tests ml.Model class.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) - mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test - mlkit.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient() + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + ml.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient() @classmethod def teardown_class(cls): @@ -361,7 +361,7 @@ def _op_url(project_id, model_id): 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) def test_model_success_err_state_lro(self): - model = mlkit.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) + model = ml.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 assert model.create_time == CREATE_TIME_DATETIME @@ -376,7 +376,7 @@ def test_model_success_err_state_lro(self): assert model.as_dict() == FULL_MODEL_ERR_STATE_LRO_JSON def test_model_success_published(self): - model = mlkit.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) + model = ml.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 assert model.create_time == CREATE_TIME_DATETIME @@ -391,7 +391,7 @@ def test_model_success_published(self): assert model.as_dict() == FULL_MODEL_PUBLISHED_JSON def test_model_keyword_based_creation_and_setters(self): - model = mlkit.Model(display_name=DISPLAY_NAME_1, tags=TAGS, model_format=TFLITE_FORMAT) + model = ml.Model(display_name=DISPLAY_NAME_1, tags=TAGS, model_format=TFLITE_FORMAT) assert model.display_name == DISPLAY_NAME_1 assert model.tags == TAGS assert model.model_format == TFLITE_FORMAT @@ -411,9 +411,9 @@ def test_model_keyword_based_creation_and_setters(self): } def test_model_format_source_creation(self): - model_source = mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) - model_format = mlkit.TFLiteFormat(model_source=model_source) - model = mlkit.Model(display_name=DISPLAY_NAME_1, model_format=model_format) + model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) + model_format = ml.TFLiteFormat(model_source=model_source) + model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) assert model.as_dict() == { 'displayName': DISPLAY_NAME_1, 'tfliteModel': { @@ -422,20 +422,20 @@ def test_model_format_source_creation(self): } def test_source_creation_from_tflite_file(self): - model_source = mlkit.TFLiteGCSModelSource.from_tflite_model_file( + model_source = ml.TFLiteGCSModelSource.from_tflite_model_file( "my_model.tflite", "my_bucket") assert model_source.as_dict() == { - 'gcsTfliteUri': 'gs://my_bucket/Firebase/MLKit/Models/my_model.tflite' + 'gcsTfliteUri': 'gs://my_bucket/Firebase/ML/Models/my_model.tflite' } def test_model_source_setters(self): - model_source = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI) + model_source = ml.TFLiteGCSModelSource(GCS_TFLITE_URI) model_source.gcs_tflite_uri = GCS_TFLITE_URI_2 assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2 def test_model_format_setters(self): - model_format = mlkit.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) + model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 assert model_format.model_source == GCS_TFLITE_MODEL_SOURCE_2 assert model_format.as_dict() == { @@ -445,9 +445,9 @@ def test_model_format_setters(self): } def test_model_as_dict_for_upload(self): - model_source = mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) - model_format = mlkit.TFLiteFormat(model_source=model_source) - model = mlkit.Model(display_name=DISPLAY_NAME_1, model_format=model_format) + model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) + model_format = ml.TFLiteFormat(model_source=model_source) + model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) assert model.as_dict(for_upload=True) == { 'displayName': DISPLAY_NAME_1, 'tfliteModel': { @@ -456,11 +456,11 @@ def test_model_as_dict_for_upload(self): } @pytest.mark.parametrize('helper_func', [ - mlkit.TFLiteGCSModelSource.from_keras_model, - mlkit.TFLiteGCSModelSource.from_saved_model + ml.TFLiteGCSModelSource.from_keras_model, + ml.TFLiteGCSModelSource.from_saved_model ]) def test_tf_not_enabled(self, helper_func): - mlkit._TF_ENABLED = False # for reliability + ml._TF_ENABLED = False # for reliability with pytest.raises(ImportError) as excinfo: helper_func(None) check_error(excinfo, ImportError) @@ -472,7 +472,7 @@ def test_tf_not_enabled(self, helper_func): ]) def test_model_display_name_validation_errors(self, display_name, exc_type): with pytest.raises(exc_type) as excinfo: - mlkit.Model(display_name=display_name) + ml.Model(display_name=display_name) check_error(excinfo, exc_type) @pytest.mark.parametrize('tags, exc_type, error_message', [ @@ -486,7 +486,7 @@ def test_model_display_name_validation_errors(self, display_name, exc_type): ]) def test_model_tags_validation_errors(self, tags, exc_type, error_message): with pytest.raises(exc_type) as excinfo: - mlkit.Model(tags=tags) + ml.Model(tags=tags) check_error(excinfo, exc_type, error_message) @pytest.mark.parametrize('model_format', [ @@ -498,7 +498,7 @@ def test_model_tags_validation_errors(self, tags, exc_type, error_message): ]) def test_model_format_validation_errors(self, model_format): with pytest.raises(TypeError) as excinfo: - mlkit.Model(model_format=model_format) + ml.Model(model_format=model_format) check_error(excinfo, TypeError, 'Model format must be a ModelFormat object.') @pytest.mark.parametrize('model_source', [ @@ -510,7 +510,7 @@ def test_model_format_validation_errors(self, model_format): ]) def test_model_source_validation_errors(self, model_source): with pytest.raises(TypeError) as excinfo: - mlkit.TFLiteFormat(model_source=model_source) + ml.TFLiteFormat(model_source=model_source) check_error(excinfo, TypeError, 'Model source must be a TFLiteModelSource object.') @pytest.mark.parametrize('uri, exc_type', [ @@ -526,18 +526,18 @@ def test_model_source_validation_errors(self, model_source): ]) def test_gcs_tflite_source_validation_errors(self, uri, exc_type): with pytest.raises(exc_type) as excinfo: - mlkit.TFLiteGCSModelSource(gcs_tflite_uri=uri) + ml.TFLiteGCSModelSource(gcs_tflite_uri=uri) check_error(excinfo, exc_type) def test_wait_for_unlocked_not_locked(self): - model = mlkit.Model(display_name="not_locked") + model = ml.Model(display_name="not_locked") model.wait_for_unlocked() def test_wait_for_unlocked(self): - recorder = instrument_mlkit_service(status=200, - operations=True, - payload=OPERATION_DONE_PUBLISHED_RESPONSE) - model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_1) + recorder = instrument_ml_service(status=200, + operations=True, + payload=OPERATION_DONE_PUBLISHED_RESPONSE) + model = ml.Model.from_dict(LOCKED_MODEL_JSON_1) model.wait_for_unlocked() assert model == FULL_MODEL_PUBLISHED assert len(recorder) == 1 @@ -545,10 +545,10 @@ def test_wait_for_unlocked(self): assert recorder[0].url == TestModel._op_url(PROJECT_ID, MODEL_ID_1) def test_wait_for_unlocked_timeout(self): - recorder = instrument_mlkit_service( + recorder = instrument_ml_service( status=200, operations=True, payload=OPERATION_NOT_DONE_RESPONSE) - mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 3 # longer so timeout applies immediately - model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_1) + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 3 # longer so timeout applies immediately + model = ml.Model.from_dict(LOCKED_MODEL_JSON_1) with pytest.raises(Exception) as excinfo: model.wait_for_unlocked(max_time_seconds=0.1) check_error(excinfo, exceptions.DeadlineExceededError, 'Polling max time exceeded.') @@ -556,12 +556,12 @@ def test_wait_for_unlocked_timeout(self): class TestCreateModel(object): - """Tests mlkit.create_model.""" + """Tests ml.create_model.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) - mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test @classmethod def teardown_class(cls): @@ -581,16 +581,16 @@ def _get_url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) def test_immediate_done(self): - instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) - model = mlkit.create_model(MODEL_1) + instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = ml.create_model(MODEL_1) assert model == CREATED_UPDATED_MODEL_1 def test_returns_locked(self): - recorder = instrument_mlkit_service( + recorder = instrument_ml_service( status=[200, 200], payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) - model = mlkit.create_model(MODEL_1) + expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) + model = ml.create_model(MODEL_1) assert model == expected_model assert len(recorder) == 2 @@ -600,23 +600,23 @@ def test_returns_locked(self): assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) def test_operation_error(self): - instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) with pytest.raises(Exception) as excinfo: - mlkit.create_model(MODEL_1) + ml.create_model(MODEL_1) # The http request succeeded, the operation returned contains a create failure check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): - instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) with pytest.raises(Exception) as excinfo: - mlkit.create_model(MODEL_1) + ml.create_model(MODEL_1) check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') def test_rpc_error_create(self): - create_recorder = instrument_mlkit_service( + create_recorder = instrument_ml_service( status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(Exception) as excinfo: - mlkit.create_model(MODEL_1) + ml.create_model(MODEL_1) check_firebase_error( excinfo, ERROR_STATUS_BAD_REQUEST, @@ -628,36 +628,36 @@ def test_rpc_error_create(self): @pytest.mark.parametrize('model', INVALID_MODEL_ARGS) def test_not_model(self, model): with pytest.raises(Exception) as excinfo: - mlkit.create_model(model) - check_error(excinfo, TypeError, 'Model must be an mlkit.Model.') + ml.create_model(model) + check_error(excinfo, TypeError, 'Model must be an ml.Model.') def test_missing_display_name(self): with pytest.raises(Exception) as excinfo: - mlkit.create_model(mlkit.Model.from_dict({})) + ml.create_model(ml.Model.from_dict({})) check_error(excinfo, ValueError, 'Model must have a display name.') def test_missing_op_name(self): - instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) with pytest.raises(Exception) as excinfo: - mlkit.create_model(MODEL_1) + ml.create_model(MODEL_1) check_error(excinfo, TypeError) @pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS) def test_invalid_op_name(self, op_name): payload = json.dumps({'name': op_name}) - instrument_mlkit_service(status=200, payload=payload) + instrument_ml_service(status=200, payload=payload) with pytest.raises(Exception) as excinfo: - mlkit.create_model(MODEL_1) + ml.create_model(MODEL_1) check_error(excinfo, ValueError, 'Operation name format is invalid.') class TestUpdateModel(object): - """Tests mlkit.update_model.""" + """Tests ml.update_model.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) - mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test @classmethod def teardown_class(cls): @@ -673,16 +673,16 @@ def _op_url(project_id, model_id): 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) def test_immediate_done(self): - instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) - model = mlkit.update_model(MODEL_1) + instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = ml.update_model(MODEL_1) assert model == CREATED_UPDATED_MODEL_1 def test_returns_locked(self): - recorder = instrument_mlkit_service( + recorder = instrument_ml_service( status=[200, 200], payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) - model = mlkit.update_model(MODEL_1) + expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) + model = ml.update_model(MODEL_1) assert model == expected_model assert len(recorder) == 2 @@ -692,23 +692,23 @@ def test_returns_locked(self): assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) def test_operation_error(self): - instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) with pytest.raises(Exception) as excinfo: - mlkit.update_model(MODEL_1) + ml.update_model(MODEL_1) # The http request succeeded, the operation returned contains an update failure check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): - instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) with pytest.raises(Exception) as excinfo: - mlkit.update_model(MODEL_1) + ml.update_model(MODEL_1) check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') def test_rpc_error(self): - create_recorder = instrument_mlkit_service( + create_recorder = instrument_ml_service( status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(Exception) as excinfo: - mlkit.update_model(MODEL_1) + ml.update_model(MODEL_1) check_firebase_error( excinfo, ERROR_STATUS_BAD_REQUEST, @@ -720,35 +720,35 @@ def test_rpc_error(self): @pytest.mark.parametrize('model', INVALID_MODEL_ARGS) def test_not_model(self, model): with pytest.raises(Exception) as excinfo: - mlkit.update_model(model) - check_error(excinfo, TypeError, 'Model must be an mlkit.Model.') + ml.update_model(model) + check_error(excinfo, TypeError, 'Model must be an ml.Model.') def test_missing_display_name(self): with pytest.raises(Exception) as excinfo: - mlkit.update_model(mlkit.Model.from_dict({})) + ml.update_model(ml.Model.from_dict({})) check_error(excinfo, ValueError, 'Model must have a display name.') def test_missing_op_name(self): - instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) with pytest.raises(Exception) as excinfo: - mlkit.update_model(MODEL_1) + ml.update_model(MODEL_1) check_error(excinfo, TypeError) @pytest.mark.parametrize('op_name', INVALID_OP_NAME_ARGS) def test_invalid_op_name(self, op_name): payload = json.dumps({'name': op_name}) - instrument_mlkit_service(status=200, payload=payload) + instrument_ml_service(status=200, payload=payload) with pytest.raises(Exception) as excinfo: - mlkit.update_model(MODEL_1) + ml.update_model(MODEL_1) check_error(excinfo, ValueError, 'Operation name format is invalid.') class TestPublishUnpublish(object): - """Tests mlkit.publish_model and mlkit.unpublish_model.""" + """Tests ml.publish_model and ml.unpublish_model.""" PUBLISH_UNPUBLISH_WITH_ARGS = [ - (mlkit.publish_model, True), - (mlkit.unpublish_model, False) + (ml.publish_model, True), + (ml.unpublish_model, False) ] PUBLISH_UNPUBLISH_FUNCS = [item[0] for item in PUBLISH_UNPUBLISH_WITH_ARGS] @@ -756,7 +756,7 @@ class TestPublishUnpublish(object): def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) - mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + ml._MLService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test @classmethod def teardown_class(cls): @@ -779,7 +779,7 @@ def _op_url(project_id, model_id): @pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS) def test_immediate_done(self, publish_function, published): - recorder = instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) + recorder = instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) model = publish_function(MODEL_ID_1) assert model == CREATED_UPDATED_MODEL_1 assert len(recorder) == 1 @@ -790,10 +790,10 @@ def test_immediate_done(self, publish_function, published): @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_returns_locked(self, publish_function): - recorder = instrument_mlkit_service( + recorder = instrument_ml_service( status=[200, 200], payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) + expected_model = ml.Model.from_dict(LOCKED_MODEL_JSON_2) model = publish_function(MODEL_ID_1) assert model == expected_model @@ -805,7 +805,7 @@ def test_returns_locked(self, publish_function): @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_operation_error(self, publish_function): - instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) with pytest.raises(Exception) as excinfo: publish_function(MODEL_ID_1) # The http request succeeded, the operation returned contains an update failure @@ -813,14 +813,14 @@ def test_operation_error(self, publish_function): @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_malformed_operation(self, publish_function): - instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + instrument_ml_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) with pytest.raises(Exception) as excinfo: publish_function(MODEL_ID_1) check_error(excinfo, exceptions.UnknownError, 'Internal Error: Malformed Operation.') @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_rpc_error(self, publish_function): - create_recorder = instrument_mlkit_service( + create_recorder = instrument_ml_service( status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(Exception) as excinfo: publish_function(MODEL_ID_1) @@ -834,7 +834,7 @@ def test_rpc_error(self, publish_function): class TestGetModel(object): - """Tests mlkit.get_model.""" + """Tests ml.get_model.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() @@ -849,8 +849,8 @@ def _url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) def test_get_model(self): - recorder = instrument_mlkit_service(status=200, payload=DEFAULT_GET_RESPONSE) - model = mlkit.get_model(MODEL_ID_1) + recorder = instrument_ml_service(status=200, payload=DEFAULT_GET_RESPONSE) + model = ml.get_model(MODEL_ID_1) assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) @@ -861,13 +861,13 @@ def test_get_model(self): @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) def test_get_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as excinfo: - mlkit.get_model(model_id) + ml.get_model(model_id) check_error(excinfo, exc_type) def test_get_model_error(self): - recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) + recorder = instrument_ml_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) with pytest.raises(exceptions.NotFoundError) as excinfo: - mlkit.get_model(MODEL_ID_1) + ml.get_model(MODEL_ID_1) check_firebase_error( excinfo, ERROR_STATUS_NOT_FOUND, @@ -882,12 +882,12 @@ def test_no_project_id(self): def evaluate(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') with pytest.raises(ValueError): - mlkit.get_model(MODEL_ID_1, app) + ml.get_model(MODEL_ID_1, app) testutils.run_without_project_id(evaluate) class TestDeleteModel(object): - """Tests mlkit.delete_model.""" + """Tests ml.delete_model.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() @@ -902,8 +902,8 @@ def _url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) def test_delete_model(self): - recorder = instrument_mlkit_service(status=200, payload=EMPTY_RESPONSE) - mlkit.delete_model(MODEL_ID_1) # no response for delete + recorder = instrument_ml_service(status=200, payload=EMPTY_RESPONSE) + ml.delete_model(MODEL_ID_1) # no response for delete assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) @@ -911,13 +911,13 @@ def test_delete_model(self): @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) def test_delete_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as excinfo: - mlkit.delete_model(model_id) + ml.delete_model(model_id) check_error(excinfo, exc_type) def test_delete_model_error(self): - recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) + recorder = instrument_ml_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) with pytest.raises(exceptions.NotFoundError) as excinfo: - mlkit.delete_model(MODEL_ID_1) + ml.delete_model(MODEL_ID_1) check_firebase_error( excinfo, ERROR_STATUS_NOT_FOUND, @@ -932,12 +932,12 @@ def test_no_project_id(self): def evaluate(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') with pytest.raises(ValueError): - mlkit.delete_model(MODEL_ID_1, app) + ml.delete_model(MODEL_ID_1, app) testutils.run_without_project_id(evaluate) class TestListModels(object): - """Tests mlkit.list_models.""" + """Tests ml.list_models.""" @classmethod def setup_class(cls): cred = testutils.MockCredential() @@ -953,14 +953,14 @@ def _url(project_id): @staticmethod def _check_page(page, model_count): - assert isinstance(page, mlkit.ListModelsPage) + assert isinstance(page, ml.ListModelsPage) assert len(page.models) == model_count for model in page.models: - assert isinstance(model, mlkit.Model) + assert isinstance(model, ml.Model) def test_list_models_no_args(self): - recorder = instrument_mlkit_service(status=200, payload=DEFAULT_LIST_RESPONSE) - models_page = mlkit.list_models() + recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) + models_page = ml.list_models() assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestListModels._url(PROJECT_ID) @@ -971,8 +971,8 @@ def test_list_models_no_args(self): assert models_page.models[1] == MODEL_2 def test_list_models_with_all_args(self): - recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) - models_page = mlkit.list_models( + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + models_page = ml.list_models( 'display_name=displayName3', page_size=10, page_token=PAGE_TOKEN) @@ -982,7 +982,7 @@ def test_list_models_with_all_args(self): TestListModels._url(PROJECT_ID) + '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' .format(PAGE_TOKEN)) - assert isinstance(models_page, mlkit.ListModelsPage) + assert isinstance(models_page, ml.ListModelsPage) assert len(models_page.models) == 1 assert models_page.models[0] == MODEL_3 assert not models_page.has_next_page @@ -990,7 +990,7 @@ def test_list_models_with_all_args(self): @pytest.mark.parametrize('list_filter', INVALID_STRING_OR_NONE_ARGS) def test_list_models_list_filter_validation(self, list_filter): with pytest.raises(TypeError) as excinfo: - mlkit.list_models(list_filter=list_filter) + ml.list_models(list_filter=list_filter) check_error(excinfo, TypeError, 'List filter must be a string or None.') @pytest.mark.parametrize('page_size, exc_type, error_message', [ @@ -1001,23 +1001,23 @@ def test_list_models_list_filter_validation(self, list_filter): (True, TypeError, 'Page size must be a number or None.'), (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), - (mlkit._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) + (ml._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) ]) def test_list_models_page_size_validation(self, page_size, exc_type, error_message): with pytest.raises(exc_type) as excinfo: - mlkit.list_models(page_size=page_size) + ml.list_models(page_size=page_size) check_error(excinfo, exc_type, error_message) @pytest.mark.parametrize('page_token', INVALID_STRING_OR_NONE_ARGS) def test_list_models_page_token_validation(self, page_token): with pytest.raises(TypeError) as excinfo: - mlkit.list_models(page_token=page_token) + ml.list_models(page_token=page_token) check_error(excinfo, TypeError, 'Page token must be a string or None.') def test_list_models_error(self): - recorder = instrument_mlkit_service(status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + recorder = instrument_ml_service(status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(exceptions.InvalidArgumentError) as excinfo: - mlkit.list_models() + ml.list_models() check_firebase_error( excinfo, ERROR_STATUS_BAD_REQUEST, @@ -1032,12 +1032,12 @@ def test_no_project_id(self): def evaluate(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') with pytest.raises(ValueError): - mlkit.list_models(app=app) + ml.list_models(app=app) testutils.run_without_project_id(evaluate) def test_list_single_page(self): - recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) - models_page = mlkit.list_models() + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + models_page = ml.list_models() assert len(recorder) == 1 assert models_page.next_page_token == '' assert models_page.has_next_page is False @@ -1047,15 +1047,15 @@ def test_list_single_page(self): def test_list_multiple_pages(self): # Page 1 - recorder = instrument_mlkit_service(status=200, payload=DEFAULT_LIST_RESPONSE) - page = mlkit.list_models() + recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) + page = ml.list_models() assert len(recorder) == 1 assert len(page.models) == 2 assert page.next_page_token == NEXT_PAGE_TOKEN assert page.has_next_page is True # Page 2 - recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) page_2 = page.get_next_page() assert len(recorder) == 1 assert len(page_2.models) == 1 @@ -1065,8 +1065,8 @@ def test_list_multiple_pages(self): def test_list_models_paged_iteration(self): # Page 1 - recorder = instrument_mlkit_service(status=200, payload=DEFAULT_LIST_RESPONSE) - page = mlkit.list_models() + recorder = instrument_ml_service(status=200, payload=DEFAULT_LIST_RESPONSE) + page = ml.list_models() assert page.next_page_token == NEXT_PAGE_TOKEN assert page.has_next_page is True iterator = page.iterate_all() @@ -1076,15 +1076,15 @@ def test_list_models_paged_iteration(self): assert len(recorder) == 1 # Page 2 - recorder = instrument_mlkit_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) + recorder = instrument_ml_service(status=200, payload=LAST_PAGE_LIST_RESPONSE) model = next(iterator) assert model.display_name == DISPLAY_NAME_3 with pytest.raises(StopIteration): next(iterator) def test_list_models_stop_iteration(self): - recorder = instrument_mlkit_service(status=200, payload=ONE_PAGE_LIST_RESPONSE) - page = mlkit.list_models() + recorder = instrument_ml_service(status=200, payload=ONE_PAGE_LIST_RESPONSE) + page = ml.list_models() assert len(recorder) == 1 assert len(page.models) == 3 iterator = page.iterate_all() @@ -1095,8 +1095,8 @@ def test_list_models_stop_iteration(self): assert len(models) == 3 def test_list_models_no_models(self): - recorder = instrument_mlkit_service(status=200, payload=NO_MODELS_LIST_RESPONSE) - page = mlkit.list_models() + recorder = instrument_ml_service(status=200, payload=NO_MODELS_LIST_RESPONSE) + page = ml.list_models() assert len(recorder) == 1 assert len(page.models) == 0 models = [model for model in page.iterate_all()]