diff --git a/integration/test_ml.py b/integration/test_ml.py new file mode 100644 index 000000000..4c44f3d10 --- /dev/null +++ b/integration/test_ml.py @@ -0,0 +1,373 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for firebase_admin.ml module.""" +import os +import random +import re +import shutil +import string +import tempfile +import pytest + + +from firebase_admin import exceptions +from firebase_admin import ml +from tests import testutils + + +# pylint: disable=import-error,no-name-in-module +try: + import tensorflow as tf + _TF_ENABLED = True +except ImportError: + _TF_ENABLED = False + + +def _random_identifier(prefix): + #pylint: disable=unused-variable + suffix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)]) + return '{0}_{1}'.format(prefix, suffix) + + +NAME_ONLY_ARGS = { + 'display_name': _random_identifier('TestModel123_') +} +NAME_ONLY_ARGS_UPDATED = { + 'display_name': _random_identifier('TestModel123_updated_') +} +NAME_AND_TAGS_ARGS = { + 'display_name': _random_identifier('TestModel123_tags_'), + 'tags': ['test_tag123'] +} +FULL_MODEL_ARGS = { + 'display_name': _random_identifier('TestModel123_full_'), + 'tags': ['test_tag567'], + 'file_name': 'model1.tflite' +} +INVALID_FULL_MODEL_ARGS = { + 'display_name': _random_identifier('TestModel123_invalid_full_'), + 'tags': ['test_tag890'], + 'file_name': 'invalid_model.tflite' +} + + +@pytest.fixture +def firebase_model(request): + args = request.param + tflite_format = None + file_name = args.get('file_name') + if file_name: + file_path = testutils.resource_filename(file_name) + source = ml.TFLiteGCSModelSource.from_tflite_model_file(file_path) + tflite_format = ml.TFLiteFormat(model_source=source) + + ml_model = ml.Model( + display_name=args.get('display_name'), + tags=args.get('tags'), + model_format=tflite_format) + model = ml.create_model(model=ml_model) + yield model + _clean_up_model(model) + + +@pytest.fixture +def model_list(): + ml_model_1 = ml.Model(display_name=_random_identifier('TestModel123_list1_')) + model_1 = ml.create_model(model=ml_model_1) + + ml_model_2 = ml.Model(display_name=_random_identifier('TestModel123_list2_'), + tags=['test_tag123']) + model_2 = ml.create_model(model=ml_model_2) + + yield [model_1, model_2] + + _clean_up_model(model_1) + _clean_up_model(model_2) + + +def _clean_up_model(model): + try: + # Try to delete the model. + # Some tests delete the model as part of the test. + ml.delete_model(model.model_id) + except exceptions.NotFoundError: + pass + + +# For rpc errors +def check_firebase_error(excinfo, status, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert err.cause is not None + assert err.http_response is not None + assert err.http_response.status_code == status + assert str(err) == msg + + +# For operation errors +def check_operation_error(excinfo, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert str(err) == msg + + +def check_model(model, args): + assert model.display_name == args.get('display_name') + assert model.tags == args.get('tags') + assert model.model_id is not None + assert model.create_time is not None + assert model.update_time is not None + assert model.locked is False + assert model.etag is not None + + +def check_model_format(model, has_model_format=False, validation_error=None): + if has_model_format: + assert model.validation_error == validation_error + assert model.published is False + assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://') + if validation_error: + assert model.model_format.size_bytes is None + assert model.model_hash is None + else: + assert model.model_format.size_bytes is not None + assert model.model_hash is not None + else: + assert model.model_format is None + assert model.validation_error == 'No model file has been uploaded.' + assert model.published is False + assert model.model_hash is None + + +@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) +def test_create_simple_model(firebase_model): + check_model(firebase_model, NAME_AND_TAGS_ARGS) + check_model_format(firebase_model) + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_create_full_model(firebase_model): + check_model(firebase_model, FULL_MODEL_ARGS) + check_model_format(firebase_model, True) + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_create_already_existing_fails(firebase_model): + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: + ml.create_model(model=firebase_model) + check_operation_error( + excinfo, + 'Model \'{0}\' already exists'.format(firebase_model.display_name)) + + +@pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True) +def test_create_invalid_model(firebase_model): + check_model(firebase_model, INVALID_FULL_MODEL_ARGS) + check_model_format(firebase_model, True, 'Invalid flatbuffer format') + + +@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) +def test_get_model(firebase_model): + get_model = ml.get_model(firebase_model.model_id) + check_model(get_model, NAME_AND_TAGS_ARGS) + check_model_format(get_model) + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_get_non_existing_model(firebase_model): + # Get a valid model_id that no longer exists + ml.delete_model(firebase_model.model_id) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.get_model(firebase_model.model_id) + check_firebase_error(excinfo, 404, 'Requested entity was not found.') + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_update_model(firebase_model): + new_model_name = NAME_ONLY_ARGS_UPDATED.get('display_name') + firebase_model.display_name = new_model_name + updated_model = ml.update_model(firebase_model) + check_model(updated_model, NAME_ONLY_ARGS_UPDATED) + check_model_format(updated_model) + + # Second call with same model does not cause error + updated_model2 = ml.update_model(updated_model) + check_model(updated_model2, NAME_ONLY_ARGS_UPDATED) + check_model_format(updated_model2) + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_update_non_existing_model(firebase_model): + ml.delete_model(firebase_model.model_id) + + firebase_model.tags = ['tag987'] + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.update_model(firebase_model) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_publish_unpublish_model(firebase_model): + assert firebase_model.published is False + + published_model = ml.publish_model(firebase_model.model_id) + assert published_model.published is True + + unpublished_model = ml.unpublish_model(published_model.model_id) + assert unpublished_model.published is False + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_publish_invalid_fails(firebase_model): + assert firebase_model.validation_error is not None + + with pytest.raises(exceptions.FailedPreconditionError) as excinfo: + ml.publish_model(firebase_model.model_id) + check_operation_error( + excinfo, + 'Cannot publish a model that is not verified.') + + +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_publish_unpublish_non_existing_model(firebase_model): + ml.delete_model(firebase_model.model_id) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.publish_model(firebase_model.model_id) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.unpublish_model(firebase_model.model_id) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + + +def test_list_models(model_list): + filter_str = 'displayName={0} OR tags:{1}'.format( + model_list[0].display_name, model_list[1].tags[0]) + + all_models = ml.list_models(list_filter=filter_str) + all_model_ids = [mdl.model_id for mdl in all_models.iterate_all()] + for mdl in model_list: + assert mdl.model_id in all_model_ids + + +def test_list_models_invalid_filter(): + invalid_filter = 'InvalidFilterParam=123' + + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: + ml.list_models(list_filter=invalid_filter) + check_firebase_error(excinfo, 400, 'Request contains an invalid argument.') + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_delete_model(firebase_model): + ml.delete_model(firebase_model.model_id) + + # Second delete of same model will fail + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.delete_model(firebase_model.model_id) + check_firebase_error(excinfo, 404, 'Requested entity was not found.') + + +# Test tensor flow conversion functions if tensor flow is enabled. +#'pip install tensorflow' in the environment if you want _TF_ENABLED = True +#'pip install tensorflow==2.0.0b' for version 2 etc. + + +def _clean_up_directory(save_dir): + if save_dir.startswith(tempfile.gettempdir()) and os.path.exists(save_dir): + shutil.rmtree(save_dir) + + +@pytest.fixture +def keras_model(): + assert _TF_ENABLED + x_array = [-1, 0, 1, 2, 3, 4] + y_array = [-3, -1, 1, 3, 5, 7] + model = tf.keras.models.Sequential( + [tf.keras.layers.Dense(units=1, input_shape=[1])]) + model.compile(optimizer='sgd', loss='mean_squared_error') + model.fit(x_array, y_array, epochs=3) + return model + + +@pytest.fixture +def saved_model_dir(keras_model): + assert _TF_ENABLED + # Make a new parent directory. The child directory must not exist yet. + # The child directory gets created by tf. If it exists, the tf call fails. + parent = tempfile.mkdtemp() + save_dir = os.path.join(parent, 'child') + + # different versions have different model conversion capability + # pick something that works for each version + if tf.version.VERSION.startswith('1.'): + tf.reset_default_graph() + x_var = tf.placeholder(tf.float32, (None, 3), name="x") + y_var = tf.multiply(x_var, x_var, name="y") + with tf.Session() as sess: + tf.saved_model.simple_save(sess, save_dir, {"x": x_var}, {"y": y_var}) + else: + # If it's not version 1.x or version 2.x we need to update the test. + assert tf.version.VERSION.startswith('2.') + tf.saved_model.save(keras_model, save_dir) + yield save_dir + _clean_up_directory(parent) + + +@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.') +def test_from_keras_model(keras_model): + source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite') + assert re.search( + '^gs://.*/Firebase/ML/Models/model2.tflite$', + source.gcs_tflite_uri) is not None + + # Validate the conversion by creating a model + model_format = ml.TFLiteFormat(model_source=source) + model = ml.Model(display_name=_random_identifier('KerasModel_'), model_format=model_format) + created_model = ml.create_model(model) + + try: + check_model(created_model, {'display_name': model.display_name}) + check_model_format(created_model, True) + finally: + _clean_up_model(created_model) + + +@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.') +def test_from_saved_model(saved_model_dir): + # Test the conversion helper + source = ml.TFLiteGCSModelSource.from_saved_model(saved_model_dir, 'model3.tflite') + assert re.search( + '^gs://.*/Firebase/ML/Models/model3.tflite$', + source.gcs_tflite_uri) is not None + + # Validate the conversion by creating a model + model_format = ml.TFLiteFormat(model_source=source) + model = ml.Model(display_name=_random_identifier('SavedModel_'), model_format=model_format) + created_model = ml.create_model(model) + + try: + assert created_model.model_id is not None + assert created_model.validation_error is None + finally: + _clean_up_model(created_model) diff --git a/tests/data/invalid_model.tflite b/tests/data/invalid_model.tflite new file mode 100644 index 000000000..d8482f436 --- /dev/null +++ b/tests/data/invalid_model.tflite @@ -0,0 +1 @@ +This is not a tflite file. diff --git a/tests/data/model1.tflite b/tests/data/model1.tflite new file mode 100644 index 000000000..c4b71b7a2 Binary files /dev/null and b/tests/data/model1.tflite differ