From 7a7a2a58ffbeed230b3ce93924ade607430c2ee2 Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 27 Jan 2020 18:03:28 -0500 Subject: [PATCH 1/4] Integration tests for Firebase ML --- integration/test_ml.py | 389 ++++++++++++++++++++++++++++++++ tests/data/invalid_model.tflite | 1 + tests/data/model1.tflite | Bin 0 -> 736 bytes 3 files changed, 390 insertions(+) create mode 100644 integration/test_ml.py create mode 100644 tests/data/invalid_model.tflite create mode 100644 tests/data/model1.tflite diff --git a/integration/test_ml.py b/integration/test_ml.py new file mode 100644 index 000000000..29a338704 --- /dev/null +++ b/integration/test_ml.py @@ -0,0 +1,389 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for firebase_admin.ml module.""" +import re +import pytest + + +from firebase_admin import ml +from firebase_admin import exceptions +from tests import testutils + + +# pylint: disable=import-error,no-name-in-module +try: + import tensorflow as tf + import os # This is only needed for the tensorflow testing + import shutil # This is only needed for the tensorflow testing + _TF_ENABLED = True +except ImportError: + _TF_ENABLED = False + + +@pytest.fixture +def name_only_model(): + model = ml.Model(display_name="TestModel123") + yield model + + +@pytest.fixture +def name_and_tags_model(): + model = ml.Model(display_name="TestModel123_tags", tags=['test_tag123']) + yield model + + +@pytest.fixture +def full_model(): + tflite_file_name = testutils.resource_filename('model1.tflite') + source1 = ml.TFLiteGCSModelSource.from_tflite_model_file(tflite_file_name) + format1 = ml.TFLiteFormat(model_source=source1) + model = ml.Model( + display_name="TestModel123_full", + tags=['test_tag567'], + model_format=format1) + yield model + + +@pytest.fixture +def invalid_full_model(): + tflite_file_name = testutils.resource_filename('invalid_model.tflite') + source1 = ml.TFLiteGCSModelSource.from_tflite_model_file(tflite_file_name) + format1 = ml.TFLiteFormat(model_source=source1) + model = ml.Model( + display_name="TestModel123_invalid_full", + tags=['test_tag890'], + model_format=format1) + yield model + + +# For rpc errors +def check_firebase_error(excinfo, status, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert err.cause is not None + assert err.http_response is not None + assert err.http_response.status_code == status + assert str(err) == msg + + +# For operation errors +def check_operation_error(excinfo, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert str(err) == msg + + +def _ensure_model_exists(model): + # Delete any previously existing model with the same name because + # it may be modified from the model that is passed in. + _delete_if_exists(model) + + # And recreate using the model passed in + created_model = ml.create_model(model=model) + return created_model + + +# Use this when you know the model_id and are sure it exists. +def _clean_up_model(model): + ml.delete_model(model.model_id) + + +# Use this when you don't know the model_id or it may not exist. +def _delete_if_exists(model): + filter_str = 'displayName={0}'.format(model.display_name) + models_list = ml.list_models(list_filter=filter_str) + for mdl in models_list.models: + ml.delete_model(mdl.model_id) + + +def test_create_simple_model(name_and_tags_model): + _delete_if_exists(name_and_tags_model) + + firebase_model = ml.create_model(model=name_and_tags_model) + assert firebase_model.display_name == name_and_tags_model.display_name + assert firebase_model.tags == name_and_tags_model.tags + assert firebase_model.model_id is not None + assert firebase_model.create_time is not None + assert firebase_model.update_time is not None + assert firebase_model.validation_error == 'No model file has been uploaded.' + assert firebase_model.locked is False + assert firebase_model.published is False + assert firebase_model.etag is not None + assert firebase_model.model_hash is None + + _clean_up_model(firebase_model) + +def test_create_full_model(full_model): + _delete_if_exists(full_model) + + firebase_model = ml.create_model(model=full_model) + assert firebase_model.display_name == full_model.display_name + assert firebase_model.tags == full_model.tags + assert firebase_model.model_format.size_bytes is not None + assert firebase_model.model_format.model_source == full_model.model_format.model_source + assert firebase_model.model_id is not None + assert firebase_model.create_time is not None + assert firebase_model.update_time is not None + assert firebase_model.validation_error is None + assert firebase_model.locked is False + assert firebase_model.published is False + assert firebase_model.etag is not None + assert firebase_model.model_hash is not None + + _clean_up_model(firebase_model) + + +def test_create_already_existing_fails(full_model): + _ensure_model_exists(full_model) + with pytest.raises(exceptions.AlreadyExistsError) as excinfo: + ml.create_model(model=full_model) + check_operation_error( + excinfo, + 'Model \'{0}\' already exists'.format(full_model.display_name)) + + +def test_create_invalid_model(invalid_full_model): + _delete_if_exists(invalid_full_model) + + firebase_model = ml.create_model(model=invalid_full_model) + assert firebase_model.display_name == invalid_full_model.display_name + assert firebase_model.tags == invalid_full_model.tags + assert firebase_model.model_format.size_bytes is None + assert firebase_model.model_format.model_source == invalid_full_model.model_format.model_source + assert firebase_model.model_id is not None + assert firebase_model.create_time is not None + assert firebase_model.update_time is not None + assert firebase_model.validation_error == 'Invalid flatbuffer format' + assert firebase_model.locked is False + assert firebase_model.published is False + assert firebase_model.etag is not None + assert firebase_model.model_hash is None + + _clean_up_model(firebase_model) + +def test_get_model(name_only_model): + existing_model = _ensure_model_exists(name_only_model) + + firebase_model = ml.get_model(existing_model.model_id) + assert firebase_model.display_name == name_only_model.display_name + assert firebase_model.model_id is not None + assert firebase_model.create_time is not None + assert firebase_model.update_time is not None + assert firebase_model.validation_error == 'No model file has been uploaded.' + assert firebase_model.etag is not None + assert firebase_model.locked is False + assert firebase_model.published is False + assert firebase_model.model_hash is None + + _clean_up_model(firebase_model) + + +def test_get_non_existing_model(name_only_model): + # Get a valid model_id that no longer exists + model = _ensure_model_exists(name_only_model) + ml.delete_model(model.model_id) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.get_model(model.model_id) + check_firebase_error(excinfo, 404, 'Requested entity was not found.') + + +def test_update_model(name_only_model): + new_model_name = 'TestModel123_updated' + _delete_if_exists(ml.Model(display_name=new_model_name)) + existing_model = _ensure_model_exists(name_only_model) + existing_model.display_name = new_model_name + + firebase_model = ml.update_model(existing_model) + assert firebase_model.display_name == new_model_name + assert firebase_model.model_id == existing_model.model_id + assert firebase_model.create_time == existing_model.create_time + assert firebase_model.update_time != existing_model.update_time + assert firebase_model.validation_error == existing_model.validation_error + assert firebase_model.etag != existing_model.etag + assert firebase_model.published == existing_model.published + assert firebase_model.locked == existing_model.locked + + # Second call with same model does not cause error + firebase_model2 = ml.update_model(firebase_model) + assert firebase_model2.display_name == firebase_model.display_name + assert firebase_model2.model_id == firebase_model.model_id + assert firebase_model2.create_time == firebase_model.create_time + assert firebase_model2.update_time != firebase_model.update_time + assert firebase_model2.validation_error == firebase_model.validation_error + assert firebase_model2.etag != existing_model.etag + assert firebase_model2.published == firebase_model.published + assert firebase_model2.locked == firebase_model.locked + + _clean_up_model(firebase_model) + + +def test_update_non_existing_model(name_only_model): + model = _ensure_model_exists(name_only_model) + ml.delete_model(model.model_id) + + model.tags = ['tag987'] + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.update_model(model) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(model.as_dict().get('name'))) + +def test_publish_unpublish_model(full_model): + model = _ensure_model_exists(full_model) + assert model.published is False + + published_model = ml.publish_model(model.model_id) + assert published_model.published is True + + unpublished_model = ml.unpublish_model(published_model.model_id) + assert unpublished_model.published is False + + _clean_up_model(unpublished_model) + + +def test_publish_invalid_fails(name_only_model): + model = _ensure_model_exists(name_only_model) + assert model.validation_error is not None + + with pytest.raises(exceptions.FailedPreconditionError) as excinfo: + ml.publish_model(model.model_id) + check_operation_error( + excinfo, + 'Cannot publish a model that is not verified.') + + +def test_publish_unpublish_non_existing_model(full_model): + model = _ensure_model_exists(full_model) + ml.delete_model(model.model_id) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.publish_model(model.model_id) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(model.as_dict().get('name'))) + + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.unpublish_model(model.model_id) + check_operation_error( + excinfo, + 'Model \'{0}\' was not found'.format(model.as_dict().get('name'))) + + +def test_list_models(name_only_model, name_and_tags_model): + existing_model1 = _ensure_model_exists(name_only_model) + existing_model2 = _ensure_model_exists(name_and_tags_model) + filter_str = 'displayName={0} OR tags:{1}'.format( + existing_model1.display_name, existing_model2.tags[0]) + + models_list = ml.list_models(list_filter=filter_str) + assert len(models_list.models) == 2 + for mdl in models_list.models: + assert mdl == existing_model1 or mdl == existing_model2 + assert models_list.models[0] != models_list.models[1] + + _clean_up_model(existing_model1) + _clean_up_model(existing_model2) + + +def test_list_models_invalid_filter(): + invalid_filter = 'InvalidFilterParam=123' + + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: + ml.list_models(list_filter=invalid_filter) + check_firebase_error(excinfo, 400, 'Request contains an invalid argument.') + + +def test_delete_model(name_only_model): + existing_model = _ensure_model_exists(name_only_model) + + ml.delete_model(existing_model.model_id) + + # Second delete of same model will fail + with pytest.raises(exceptions.NotFoundError) as excinfo: + ml.delete_model(existing_model.model_id) + check_firebase_error(excinfo, 404, 'Requested entity was not found.') + + +#'pip install tensorflow' in the environment if you want _TF_ENABLED = True +#'pip install tensorflow=2.0.0' for version 2 etc. +if _TF_ENABLED: + # Test tensor flow conversion functions if tensor flow is enabled. + SAVED_MODEL_DIR = '/tmp/saved_model/1' + + def _clean_up_tmp_directory(): + if os.path.exists(SAVED_MODEL_DIR): + shutil.rmtree(SAVED_MODEL_DIR) + + @pytest.fixture + def keras_model(): + x_array = [-1, 0, 1, 2, 3, 4] + y_array = [-3, -1, 1, 3, 5, 7] + model = tf.keras.models.Sequential( + [tf.keras.layers.Dense(units=1, input_shape=[1])]) + model.compile(optimizer='sgd', loss='mean_squared_error') + model.fit(x_array, y_array, epochs=3) + yield model + + @pytest.fixture + def saved_model_dir(keras_model): + # different versions have different model conversion capability + # pick something that works for each version + save_dir = SAVED_MODEL_DIR + _clean_up_tmp_directory() # previous failures may leave files + if tf.version.VERSION.startswith('1.'): + tf.reset_default_graph() + x_var = tf.placeholder(tf.float32, (None, 3), name="x") + y_var = tf.multiply(x_var, x_var, name="y") + with tf.Session() as sess: + tf.saved_model.simple_save(sess, save_dir, {"x": x_var}, {"y": y_var}) + else: + # If it's not version 1.x or version 2.x we need to update the test. + assert tf.version.VERSION.startswith('2.') + tf.saved_model.save(keras_model, save_dir) + yield save_dir + + + def test_from_keras_model(keras_model): + source1 = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite') + assert re.search( + '^gs://.*/Firebase/ML/Models/model2.tflite$', + source1.gcs_tflite_uri) is not None + + # Validate the conversion by creating a model + format1 = ml.TFLiteFormat(model_source=source1) + model1 = ml.Model(display_name="KerasModel1", model_format=format1) + firebase_model = ml.create_model(model1) + assert firebase_model.model_id is not None + assert firebase_model.validation_error is None + + _clean_up_model(firebase_model) + + def test_from_saved_model(saved_model_dir): + # Test the conversion helper + source1 = ml.TFLiteGCSModelSource.from_saved_model(saved_model_dir, 'model3.tflite') + assert re.search( + '^gs://.*/Firebase/ML/Models/model3.tflite$', + source1.gcs_tflite_uri) is not None + + # Validate the conversion by creating a model + format1 = ml.TFLiteFormat(model_source=source1) + model1 = ml.Model(display_name="SavedModel1", model_format=format1) + firebase_model = ml.create_model(model1) + assert firebase_model.model_id is not None + assert firebase_model.validation_error is None + + _clean_up_model(firebase_model) + _clean_up_tmp_directory() diff --git a/tests/data/invalid_model.tflite b/tests/data/invalid_model.tflite new file mode 100644 index 000000000..d8482f436 --- /dev/null +++ b/tests/data/invalid_model.tflite @@ -0,0 +1 @@ +This is not a tflite file. diff --git a/tests/data/model1.tflite b/tests/data/model1.tflite new file mode 100644 index 0000000000000000000000000000000000000000..c4b71b7a222ebc59ee9fa1239fe2b8efb382cf8b GIT binary patch literal 736 zcmaJY5FPb2r$h~!B1MWTQ%GV^!A6@CK}ZNlustqhi-Y8H-iIg%{+S@QcK(Dk zR)Rl3tSqc7Y;=9^?h;aY@NQ;j_RY-O-KvOmPg{E;TT&H6Oeso9%7|7F5m^Gpi-MRS zFR}v^fd$|pw*l-X(CyeA%O3exDvVXXE-Q$g0Ea*gAfI&%;7x1IS|70Au#+FH4KSD! zSQ8%k6Eu29ZW(^Feo)_K>{n|Oma%PM==n~V_^~%s4thu4$j6N3nHtV(0aQiaEoyRp zezXMpUIWy`Jtl%{m~A>Qxh()kG2`sRkJM$N(Aph1%|>7Ok%DczaXT3_&XwE0a6`}S z4OAy+#G&g)!6;Io$rCiN=X3S*IL!O-tl7r`=KFA-Gt`c~_y(?gfqS2GxQ`s3wFx?^MfSDns>_^X1%E{Y8Sb)GfRIXJ-KXm39D=`^W;!7eZm6%(eLy;H^LSfV^(T? zeSA40uL6|QKPM`rbs3=!d?x$wt<-YMb0LrVras*CGoXmIndd!cZ>NyH9V}M=02G>W Ai~s-t literal 0 HcmV?d00001 From 99ccb6d43a8a4b2405d1c2a67e8bfe7d02971b73 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 29 Jan 2020 13:36:37 -0500 Subject: [PATCH 2/4] addressed review comments --- integration/test_ml.py | 451 ++++++++++++++++++++--------------------- 1 file changed, 221 insertions(+), 230 deletions(-) diff --git a/integration/test_ml.py b/integration/test_ml.py index 29a338704..53e4abcc5 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -14,6 +14,9 @@ """Integration tests for firebase_admin.ml module.""" import re +import os +import shutil +import unittest import pytest @@ -25,47 +28,68 @@ # pylint: disable=import-error,no-name-in-module try: import tensorflow as tf - import os # This is only needed for the tensorflow testing - import shutil # This is only needed for the tensorflow testing _TF_ENABLED = True except ImportError: _TF_ENABLED = False +NAME_ONLY_ARGS = { + 'display_name': 'TestModel123' +} +NAME_AND_TAGS_ARGS = { + 'display_name': 'TestModel123_tags', + 'tags': ['test_tag123'] + } +FULL_MODEL_ARGS = { + 'display_name': 'TestModel123_full', + 'tags': ['test_tag567'], + 'file_name': 'model1.tflite' + } +INVALID_FULL_MODEL_ARGS = { + 'display_name': 'TestModel123_invalid_full', + 'tags': ['test_tag890'], + 'file_name': 'invalid_model.tflite' + } + @pytest.fixture -def name_only_model(): - model = ml.Model(display_name="TestModel123") +def firebase_model(request): + args = request.param + tflite_format = None + if args.get('file_name'): + file_path = testutils.resource_filename(args.get('file_name')) + source = ml.TFLiteGCSModelSource.from_tflite_model_file(file_path) + tflite_format = ml.TFLiteFormat(model_source=source) + + ml_model = ml.Model( + display_name=args.get('display_name'), + tags=args.get('tags'), + model_format=tflite_format) + model = ml.create_model(model=ml_model) yield model + _clean_up_model(model) @pytest.fixture -def name_and_tags_model(): - model = ml.Model(display_name="TestModel123_tags", tags=['test_tag123']) - yield model +def model_list(): + ml_model_1 = ml.Model(display_name="TestModel123") + model_1 = ml.create_model(model=ml_model_1) + ml_model_2 = ml.Model(display_name="TestModel123_tags", tags=['test_tag123']) + model_2 = ml.create_model(model=ml_model_2) -@pytest.fixture -def full_model(): - tflite_file_name = testutils.resource_filename('model1.tflite') - source1 = ml.TFLiteGCSModelSource.from_tflite_model_file(tflite_file_name) - format1 = ml.TFLiteFormat(model_source=source1) - model = ml.Model( - display_name="TestModel123_full", - tags=['test_tag567'], - model_format=format1) - yield model + yield [model_1, model_2] + _clean_up_model(model_1) + _clean_up_model(model_2) -@pytest.fixture -def invalid_full_model(): - tflite_file_name = testutils.resource_filename('invalid_model.tflite') - source1 = ml.TFLiteGCSModelSource.from_tflite_model_file(tflite_file_name) - format1 = ml.TFLiteFormat(model_source=source1) - model = ml.Model( - display_name="TestModel123_invalid_full", - tags=['test_tag890'], - model_format=format1) - yield model + +def _clean_up_model(model): + try: + # Try to delete the model. + # Some tests delete the model as part of the test. + ml.delete_model(model.model_id) + except exceptions.NotFoundError: + pass # For rpc errors @@ -85,35 +109,10 @@ def check_operation_error(excinfo, msg): assert str(err) == msg -def _ensure_model_exists(model): - # Delete any previously existing model with the same name because - # it may be modified from the model that is passed in. - _delete_if_exists(model) - - # And recreate using the model passed in - created_model = ml.create_model(model=model) - return created_model - - -# Use this when you know the model_id and are sure it exists. -def _clean_up_model(model): - ml.delete_model(model.model_id) - - -# Use this when you don't know the model_id or it may not exist. -def _delete_if_exists(model): - filter_str = 'displayName={0}'.format(model.display_name) - models_list = ml.list_models(list_filter=filter_str) - for mdl in models_list.models: - ml.delete_model(mdl.model_id) - - -def test_create_simple_model(name_and_tags_model): - _delete_if_exists(name_and_tags_model) - - firebase_model = ml.create_model(model=name_and_tags_model) - assert firebase_model.display_name == name_and_tags_model.display_name - assert firebase_model.tags == name_and_tags_model.tags +@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) +def test_create_simple_model(firebase_model): + assert firebase_model.display_name == NAME_AND_TAGS_ARGS.get('display_name') + assert firebase_model.tags == NAME_AND_TAGS_ARGS.get('tags') assert firebase_model.model_id is not None assert firebase_model.create_time is not None assert firebase_model.update_time is not None @@ -122,17 +121,15 @@ def test_create_simple_model(name_and_tags_model): assert firebase_model.published is False assert firebase_model.etag is not None assert firebase_model.model_hash is None + assert firebase_model.model_format is None - _clean_up_model(firebase_model) - -def test_create_full_model(full_model): - _delete_if_exists(full_model) - firebase_model = ml.create_model(model=full_model) - assert firebase_model.display_name == full_model.display_name - assert firebase_model.tags == full_model.tags +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_create_full_model(firebase_model): + assert firebase_model.display_name == FULL_MODEL_ARGS.get('display_name') + assert firebase_model.tags == FULL_MODEL_ARGS.get('tags') assert firebase_model.model_format.size_bytes is not None - assert firebase_model.model_format.model_source == full_model.model_format.model_source + assert firebase_model.model_format.model_source.gcs_tflite_uri is not None assert firebase_model.model_id is not None assert firebase_model.create_time is not None assert firebase_model.update_time is not None @@ -142,26 +139,22 @@ def test_create_full_model(full_model): assert firebase_model.etag is not None assert firebase_model.model_hash is not None - _clean_up_model(firebase_model) - -def test_create_already_existing_fails(full_model): - _ensure_model_exists(full_model) +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_create_already_existing_fails(firebase_model): with pytest.raises(exceptions.AlreadyExistsError) as excinfo: - ml.create_model(model=full_model) + ml.create_model(model=firebase_model) check_operation_error( excinfo, - 'Model \'{0}\' already exists'.format(full_model.display_name)) - + 'Model \'{0}\' already exists'.format(firebase_model.display_name)) -def test_create_invalid_model(invalid_full_model): - _delete_if_exists(invalid_full_model) - firebase_model = ml.create_model(model=invalid_full_model) - assert firebase_model.display_name == invalid_full_model.display_name - assert firebase_model.tags == invalid_full_model.tags +@pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True) +def test_create_invalid_model(firebase_model): + assert firebase_model.display_name == INVALID_FULL_MODEL_ARGS.get('display_name') + assert firebase_model.tags == INVALID_FULL_MODEL_ARGS.get('tags') assert firebase_model.model_format.size_bytes is None - assert firebase_model.model_format.model_source == invalid_full_model.model_format.model_source + assert firebase_model.model_format.model_source.gcs_tflite_uri is not None assert firebase_model.model_id is not None assert firebase_model.create_time is not None assert firebase_model.update_time is not None @@ -171,132 +164,120 @@ def test_create_invalid_model(invalid_full_model): assert firebase_model.etag is not None assert firebase_model.model_hash is None - _clean_up_model(firebase_model) -def test_get_model(name_only_model): - existing_model = _ensure_model_exists(name_only_model) +@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) +def test_get_model(firebase_model): + get_model = ml.get_model(firebase_model.model_id) + assert get_model.display_name == firebase_model.display_name + assert get_model.tags == firebase_model.tags + assert get_model.model_id is not None + assert get_model.create_time is not None + assert get_model.update_time is not None + assert get_model.validation_error == 'No model file has been uploaded.' + assert get_model.etag is not None + assert get_model.locked is False + assert get_model.published is False + assert get_model.model_hash is None - firebase_model = ml.get_model(existing_model.model_id) - assert firebase_model.display_name == name_only_model.display_name - assert firebase_model.model_id is not None - assert firebase_model.create_time is not None - assert firebase_model.update_time is not None - assert firebase_model.validation_error == 'No model file has been uploaded.' - assert firebase_model.etag is not None - assert firebase_model.locked is False - assert firebase_model.published is False - assert firebase_model.model_hash is None - _clean_up_model(firebase_model) - - -def test_get_non_existing_model(name_only_model): +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_get_non_existing_model(firebase_model): # Get a valid model_id that no longer exists - model = _ensure_model_exists(name_only_model) - ml.delete_model(model.model_id) + ml.delete_model(firebase_model.model_id) with pytest.raises(exceptions.NotFoundError) as excinfo: - ml.get_model(model.model_id) + ml.get_model(firebase_model.model_id) check_firebase_error(excinfo, 404, 'Requested entity was not found.') -def test_update_model(name_only_model): +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_update_model(firebase_model): new_model_name = 'TestModel123_updated' - _delete_if_exists(ml.Model(display_name=new_model_name)) - existing_model = _ensure_model_exists(name_only_model) - existing_model.display_name = new_model_name - - firebase_model = ml.update_model(existing_model) - assert firebase_model.display_name == new_model_name - assert firebase_model.model_id == existing_model.model_id - assert firebase_model.create_time == existing_model.create_time - assert firebase_model.update_time != existing_model.update_time - assert firebase_model.validation_error == existing_model.validation_error - assert firebase_model.etag != existing_model.etag - assert firebase_model.published == existing_model.published - assert firebase_model.locked == existing_model.locked + firebase_model.display_name = new_model_name + + updated_model = ml.update_model(firebase_model) + assert updated_model.display_name == new_model_name + assert updated_model.model_id == firebase_model.model_id + assert updated_model.create_time == firebase_model.create_time + assert updated_model.update_time != firebase_model.update_time + assert updated_model.validation_error == firebase_model.validation_error + assert updated_model.etag != firebase_model.etag + assert updated_model.published == firebase_model.published + assert updated_model.locked == firebase_model.locked # Second call with same model does not cause error - firebase_model2 = ml.update_model(firebase_model) - assert firebase_model2.display_name == firebase_model.display_name - assert firebase_model2.model_id == firebase_model.model_id - assert firebase_model2.create_time == firebase_model.create_time - assert firebase_model2.update_time != firebase_model.update_time - assert firebase_model2.validation_error == firebase_model.validation_error - assert firebase_model2.etag != existing_model.etag - assert firebase_model2.published == firebase_model.published - assert firebase_model2.locked == firebase_model.locked - - _clean_up_model(firebase_model) - - -def test_update_non_existing_model(name_only_model): - model = _ensure_model_exists(name_only_model) - ml.delete_model(model.model_id) - - model.tags = ['tag987'] + updated_model2 = ml.update_model(updated_model) + assert updated_model2.display_name == updated_model.display_name + assert updated_model2.model_id == updated_model.model_id + assert updated_model2.create_time == updated_model.create_time + assert updated_model2.update_time != updated_model.update_time + assert updated_model2.validation_error == updated_model.validation_error + assert updated_model2.etag != updated_model.etag + assert updated_model2.published == updated_model.published + assert updated_model2.locked == updated_model.locked + + +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_update_non_existing_model(firebase_model): + ml.delete_model(firebase_model.model_id) + + firebase_model.tags = ['tag987'] with pytest.raises(exceptions.NotFoundError) as excinfo: - ml.update_model(model) + ml.update_model(firebase_model) check_operation_error( excinfo, - 'Model \'{0}\' was not found'.format(model.as_dict().get('name'))) + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) -def test_publish_unpublish_model(full_model): - model = _ensure_model_exists(full_model) - assert model.published is False - published_model = ml.publish_model(model.model_id) +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_publish_unpublish_model(firebase_model): + assert firebase_model.published is False + + published_model = ml.publish_model(firebase_model.model_id) assert published_model.published is True unpublished_model = ml.unpublish_model(published_model.model_id) assert unpublished_model.published is False - _clean_up_model(unpublished_model) - -def test_publish_invalid_fails(name_only_model): - model = _ensure_model_exists(name_only_model) - assert model.validation_error is not None +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_publish_invalid_fails(firebase_model): + assert firebase_model.validation_error is not None with pytest.raises(exceptions.FailedPreconditionError) as excinfo: - ml.publish_model(model.model_id) + ml.publish_model(firebase_model.model_id) check_operation_error( excinfo, 'Cannot publish a model that is not verified.') -def test_publish_unpublish_non_existing_model(full_model): - model = _ensure_model_exists(full_model) - ml.delete_model(model.model_id) +@pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) +def test_publish_unpublish_non_existing_model(firebase_model): + ml.delete_model(firebase_model.model_id) with pytest.raises(exceptions.NotFoundError) as excinfo: - ml.publish_model(model.model_id) + ml.publish_model(firebase_model.model_id) check_operation_error( excinfo, - 'Model \'{0}\' was not found'.format(model.as_dict().get('name'))) + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) with pytest.raises(exceptions.NotFoundError) as excinfo: - ml.unpublish_model(model.model_id) + ml.unpublish_model(firebase_model.model_id) check_operation_error( excinfo, - 'Model \'{0}\' was not found'.format(model.as_dict().get('name'))) + 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) -def test_list_models(name_only_model, name_and_tags_model): - existing_model1 = _ensure_model_exists(name_only_model) - existing_model2 = _ensure_model_exists(name_and_tags_model) +def test_list_models(model_list): filter_str = 'displayName={0} OR tags:{1}'.format( - existing_model1.display_name, existing_model2.tags[0]) + model_list[0].display_name, model_list[1].tags[0]) models_list = ml.list_models(list_filter=filter_str) assert len(models_list.models) == 2 for mdl in models_list.models: - assert mdl == existing_model1 or mdl == existing_model2 + assert mdl == model_list[0] or mdl == model_list[1] assert models_list.models[0] != models_list.models[1] - _clean_up_model(existing_model1) - _clean_up_model(existing_model2) - def test_list_models_invalid_filter(): invalid_filter = 'InvalidFilterParam=123' @@ -306,84 +287,94 @@ def test_list_models_invalid_filter(): check_firebase_error(excinfo, 400, 'Request contains an invalid argument.') -def test_delete_model(name_only_model): - existing_model = _ensure_model_exists(name_only_model) - - ml.delete_model(existing_model.model_id) +@pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) +def test_delete_model(firebase_model): + ml.delete_model(firebase_model.model_id) # Second delete of same model will fail with pytest.raises(exceptions.NotFoundError) as excinfo: - ml.delete_model(existing_model.model_id) + ml.delete_model(firebase_model.model_id) check_firebase_error(excinfo, 404, 'Requested entity was not found.') +# Test tensor flow conversion functions if tensor flow is enabled. #'pip install tensorflow' in the environment if you want _TF_ENABLED = True -#'pip install tensorflow=2.0.0' for version 2 etc. -if _TF_ENABLED: - # Test tensor flow conversion functions if tensor flow is enabled. - SAVED_MODEL_DIR = '/tmp/saved_model/1' - - def _clean_up_tmp_directory(): - if os.path.exists(SAVED_MODEL_DIR): - shutil.rmtree(SAVED_MODEL_DIR) - - @pytest.fixture - def keras_model(): - x_array = [-1, 0, 1, 2, 3, 4] - y_array = [-3, -1, 1, 3, 5, 7] - model = tf.keras.models.Sequential( - [tf.keras.layers.Dense(units=1, input_shape=[1])]) - model.compile(optimizer='sgd', loss='mean_squared_error') - model.fit(x_array, y_array, epochs=3) - yield model - - @pytest.fixture - def saved_model_dir(keras_model): - # different versions have different model conversion capability - # pick something that works for each version - save_dir = SAVED_MODEL_DIR - _clean_up_tmp_directory() # previous failures may leave files - if tf.version.VERSION.startswith('1.'): - tf.reset_default_graph() - x_var = tf.placeholder(tf.float32, (None, 3), name="x") - y_var = tf.multiply(x_var, x_var, name="y") - with tf.Session() as sess: - tf.saved_model.simple_save(sess, save_dir, {"x": x_var}, {"y": y_var}) - else: - # If it's not version 1.x or version 2.x we need to update the test. - assert tf.version.VERSION.startswith('2.') - tf.saved_model.save(keras_model, save_dir) - yield save_dir - - - def test_from_keras_model(keras_model): - source1 = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite') - assert re.search( - '^gs://.*/Firebase/ML/Models/model2.tflite$', - source1.gcs_tflite_uri) is not None - - # Validate the conversion by creating a model - format1 = ml.TFLiteFormat(model_source=source1) - model1 = ml.Model(display_name="KerasModel1", model_format=format1) - firebase_model = ml.create_model(model1) - assert firebase_model.model_id is not None - assert firebase_model.validation_error is None - - _clean_up_model(firebase_model) - - def test_from_saved_model(saved_model_dir): - # Test the conversion helper - source1 = ml.TFLiteGCSModelSource.from_saved_model(saved_model_dir, 'model3.tflite') - assert re.search( - '^gs://.*/Firebase/ML/Models/model3.tflite$', - source1.gcs_tflite_uri) is not None - - # Validate the conversion by creating a model - format1 = ml.TFLiteFormat(model_source=source1) - model1 = ml.Model(display_name="SavedModel1", model_format=format1) - firebase_model = ml.create_model(model1) - assert firebase_model.model_id is not None - assert firebase_model.validation_error is None - - _clean_up_model(firebase_model) - _clean_up_tmp_directory() +#'pip install tensorflow==2.0.0b' for version 2 etc. + + +SAVED_MODEL_DIR = '/tmp/saved_model/1' + + +def _clean_up_tmp_directory(): + if os.path.exists(SAVED_MODEL_DIR): + shutil.rmtree(SAVED_MODEL_DIR) + + +@pytest.fixture +def keras_model(): + assert _TF_ENABLED + x_array = [-1, 0, 1, 2, 3, 4] + y_array = [-3, -1, 1, 3, 5, 7] + model = tf.keras.models.Sequential( + [tf.keras.layers.Dense(units=1, input_shape=[1])]) + model.compile(optimizer='sgd', loss='mean_squared_error') + model.fit(x_array, y_array, epochs=3) + return model + + +@pytest.fixture +def saved_model_dir(keras_model): + assert _TF_ENABLED + # different versions have different model conversion capability + # pick something that works for each version + save_dir = SAVED_MODEL_DIR + _clean_up_tmp_directory() # previous failures may leave files + if tf.version.VERSION.startswith('1.'): + tf.reset_default_graph() + x_var = tf.placeholder(tf.float32, (None, 3), name="x") + y_var = tf.multiply(x_var, x_var, name="y") + with tf.Session() as sess: + tf.saved_model.simple_save(sess, save_dir, {"x": x_var}, {"y": y_var}) + else: + # If it's not version 1.x or version 2.x we need to update the test. + assert tf.version.VERSION.startswith('2.') + tf.saved_model.save(keras_model, save_dir) + yield save_dir + _clean_up_tmp_directory() + + +@unittest.skipUnless(_TF_ENABLED, 'Tensor flow is required for this test.') +def test_from_keras_model(keras_model): + source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite') + assert re.search( + '^gs://.*/Firebase/ML/Models/model2.tflite$', + source.gcs_tflite_uri) is not None + + # Validate the conversion by creating a model + try: + model_format = ml.TFLiteFormat(model_source=source) + model = ml.Model(display_name="KerasModel1", model_format=model_format) + created_model = ml.create_model(model) + assert created_model.model_id is not None + assert created_model.validation_error is None + finally: + _clean_up_model(created_model) + + +@unittest.skipUnless(_TF_ENABLED, 'Tensor flow is required for this test.') +def test_from_saved_model(saved_model_dir): + # Test the conversion helper + source = ml.TFLiteGCSModelSource.from_saved_model(saved_model_dir, 'model3.tflite') + assert re.search( + '^gs://.*/Firebase/ML/Models/model3.tflite$', + source.gcs_tflite_uri) is not None + + # Validate the conversion by creating a model + try: + model_format = ml.TFLiteFormat(model_source=source) + model = ml.Model(display_name="SavedModel1", model_format=model_format) + created_model = ml.create_model(model) + assert created_model.model_id is not None + assert created_model.validation_error is None + finally: + _clean_up_model(created_model) From 0bc734cfd975d417dff7e4ea68b191d262d148f4 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 29 Jan 2020 18:36:39 -0500 Subject: [PATCH 3/4] review suggestions #2 --- integration/test_ml.py | 171 ++++++++++++++++++----------------------- 1 file changed, 76 insertions(+), 95 deletions(-) diff --git a/integration/test_ml.py b/integration/test_ml.py index 53e4abcc5..48190f954 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -16,7 +16,8 @@ import re import os import shutil -import unittest +import random +import tempfile import pytest @@ -34,29 +35,33 @@ NAME_ONLY_ARGS = { - 'display_name': 'TestModel123' + 'display_name': 'TestModel123_{0}'.format(random.randint(1111, 9999)) +} +NAME_ONLY_ARGS_UPDATED = { + 'display_name': 'TestModel123_updated_{0}'.format(random.randint(1111, 9999)) } NAME_AND_TAGS_ARGS = { - 'display_name': 'TestModel123_tags', + 'display_name': 'TestModel123_tags_{0}'.format(random.randint(1111, 9999)), 'tags': ['test_tag123'] - } +} FULL_MODEL_ARGS = { - 'display_name': 'TestModel123_full', + 'display_name': 'TestModel123_full_{0}'.format(random.randint(1111, 9999)), 'tags': ['test_tag567'], 'file_name': 'model1.tflite' - } +} INVALID_FULL_MODEL_ARGS = { - 'display_name': 'TestModel123_invalid_full', + 'display_name': 'TestModel123_invalid_full_{0}'.format(random.randint(1111, 9999)), 'tags': ['test_tag890'], 'file_name': 'invalid_model.tflite' - } +} @pytest.fixture def firebase_model(request): args = request.param tflite_format = None - if args.get('file_name'): - file_path = testutils.resource_filename(args.get('file_name')) + file_name = args.get('file_name') + if file_name: + file_path = testutils.resource_filename(file_name) source = ml.TFLiteGCSModelSource.from_tflite_model_file(file_path) tflite_format = ml.TFLiteFormat(model_source=source) @@ -109,35 +114,44 @@ def check_operation_error(excinfo, msg): assert str(err) == msg +def check_model(model, args): + assert model.display_name == args.get('display_name') + assert model.tags == args.get('tags') + assert model.model_id is not None + assert model.create_time is not None + assert model.update_time is not None + assert model.locked is False + assert model.etag is not None + + +def check_model_format(model, has_model_format, validation_error): + if has_model_format: + assert model.validation_error == validation_error + assert model.published is False + assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://') + if validation_error: + assert model.model_format.size_bytes is None + assert model.model_hash is None + else: + assert model.model_format.size_bytes is not None + assert model.model_hash is not None + else: + assert model.model_format is None + assert model.validation_error == 'No model file has been uploaded.' + assert model.published is False + assert model.model_hash is None + + @pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) def test_create_simple_model(firebase_model): - assert firebase_model.display_name == NAME_AND_TAGS_ARGS.get('display_name') - assert firebase_model.tags == NAME_AND_TAGS_ARGS.get('tags') - assert firebase_model.model_id is not None - assert firebase_model.create_time is not None - assert firebase_model.update_time is not None - assert firebase_model.validation_error == 'No model file has been uploaded.' - assert firebase_model.locked is False - assert firebase_model.published is False - assert firebase_model.etag is not None - assert firebase_model.model_hash is None - assert firebase_model.model_format is None + check_model(firebase_model, NAME_AND_TAGS_ARGS) + check_model_format(firebase_model, False, None) @pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) def test_create_full_model(firebase_model): - assert firebase_model.display_name == FULL_MODEL_ARGS.get('display_name') - assert firebase_model.tags == FULL_MODEL_ARGS.get('tags') - assert firebase_model.model_format.size_bytes is not None - assert firebase_model.model_format.model_source.gcs_tflite_uri is not None - assert firebase_model.model_id is not None - assert firebase_model.create_time is not None - assert firebase_model.update_time is not None - assert firebase_model.validation_error is None - assert firebase_model.locked is False - assert firebase_model.published is False - assert firebase_model.etag is not None - assert firebase_model.model_hash is not None + check_model(firebase_model, FULL_MODEL_ARGS) + check_model_format(firebase_model, True, None) @pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) @@ -151,33 +165,15 @@ def test_create_already_existing_fails(firebase_model): @pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True) def test_create_invalid_model(firebase_model): - assert firebase_model.display_name == INVALID_FULL_MODEL_ARGS.get('display_name') - assert firebase_model.tags == INVALID_FULL_MODEL_ARGS.get('tags') - assert firebase_model.model_format.size_bytes is None - assert firebase_model.model_format.model_source.gcs_tflite_uri is not None - assert firebase_model.model_id is not None - assert firebase_model.create_time is not None - assert firebase_model.update_time is not None - assert firebase_model.validation_error == 'Invalid flatbuffer format' - assert firebase_model.locked is False - assert firebase_model.published is False - assert firebase_model.etag is not None - assert firebase_model.model_hash is None + check_model(firebase_model, INVALID_FULL_MODEL_ARGS) + check_model_format(firebase_model, True, 'Invalid flatbuffer format') @pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) def test_get_model(firebase_model): get_model = ml.get_model(firebase_model.model_id) - assert get_model.display_name == firebase_model.display_name - assert get_model.tags == firebase_model.tags - assert get_model.model_id is not None - assert get_model.create_time is not None - assert get_model.update_time is not None - assert get_model.validation_error == 'No model file has been uploaded.' - assert get_model.etag is not None - assert get_model.locked is False - assert get_model.published is False - assert get_model.model_hash is None + check_model(get_model, NAME_AND_TAGS_ARGS) + check_model_format(get_model, False, None) @pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) @@ -192,29 +188,16 @@ def test_get_non_existing_model(firebase_model): @pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) def test_update_model(firebase_model): - new_model_name = 'TestModel123_updated' + new_model_name = NAME_ONLY_ARGS_UPDATED.get('display_name') firebase_model.display_name = new_model_name - updated_model = ml.update_model(firebase_model) - assert updated_model.display_name == new_model_name - assert updated_model.model_id == firebase_model.model_id - assert updated_model.create_time == firebase_model.create_time - assert updated_model.update_time != firebase_model.update_time - assert updated_model.validation_error == firebase_model.validation_error - assert updated_model.etag != firebase_model.etag - assert updated_model.published == firebase_model.published - assert updated_model.locked == firebase_model.locked + check_model(updated_model, NAME_ONLY_ARGS_UPDATED) + check_model_format(updated_model, False, None) # Second call with same model does not cause error updated_model2 = ml.update_model(updated_model) - assert updated_model2.display_name == updated_model.display_name - assert updated_model2.model_id == updated_model.model_id - assert updated_model2.create_time == updated_model.create_time - assert updated_model2.update_time != updated_model.update_time - assert updated_model2.validation_error == updated_model.validation_error - assert updated_model2.etag != updated_model.etag - assert updated_model2.published == updated_model.published - assert updated_model2.locked == updated_model.locked + check_model(updated_model2, NAME_ONLY_ARGS_UPDATED) + check_model_format(updated_model2, False, None) @pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) @@ -272,11 +255,10 @@ def test_list_models(model_list): filter_str = 'displayName={0} OR tags:{1}'.format( model_list[0].display_name, model_list[1].tags[0]) - models_list = ml.list_models(list_filter=filter_str) - assert len(models_list.models) == 2 - for mdl in models_list.models: - assert mdl == model_list[0] or mdl == model_list[1] - assert models_list.models[0] != models_list.models[1] + all_models = ml.list_models(list_filter=filter_str) + all_model_ids = [mdl.model_id for mdl in all_models.iterate_all()] + for mdl in model_list: + assert mdl.model_id in all_model_ids def test_list_models_invalid_filter(): @@ -302,12 +284,9 @@ def test_delete_model(firebase_model): #'pip install tensorflow==2.0.0b' for version 2 etc. -SAVED_MODEL_DIR = '/tmp/saved_model/1' - - -def _clean_up_tmp_directory(): - if os.path.exists(SAVED_MODEL_DIR): - shutil.rmtree(SAVED_MODEL_DIR) +def _clean_up_directory(save_dir): + if save_dir.startswith(tempfile.gettempdir()) and os.path.exists(save_dir): + shutil.rmtree(save_dir) @pytest.fixture @@ -327,8 +306,8 @@ def saved_model_dir(keras_model): assert _TF_ENABLED # different versions have different model conversion capability # pick something that works for each version - save_dir = SAVED_MODEL_DIR - _clean_up_tmp_directory() # previous failures may leave files + parent = tempfile.mkdtemp() + save_dir = os.path.join(parent, 'child') if tf.version.VERSION.startswith('1.'): tf.reset_default_graph() x_var = tf.placeholder(tf.float32, (None, 3), name="x") @@ -340,10 +319,10 @@ def saved_model_dir(keras_model): assert tf.version.VERSION.startswith('2.') tf.saved_model.save(keras_model, save_dir) yield save_dir - _clean_up_tmp_directory() + _clean_up_directory(parent) -@unittest.skipUnless(_TF_ENABLED, 'Tensor flow is required for this test.') +@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.') def test_from_keras_model(keras_model): source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite') assert re.search( @@ -351,17 +330,18 @@ def test_from_keras_model(keras_model): source.gcs_tflite_uri) is not None # Validate the conversion by creating a model + model_format = ml.TFLiteFormat(model_source=source) + model = ml.Model(display_name="KerasModel1", model_format=model_format) + created_model = ml.create_model(model) + try: - model_format = ml.TFLiteFormat(model_source=source) - model = ml.Model(display_name="KerasModel1", model_format=model_format) - created_model = ml.create_model(model) assert created_model.model_id is not None assert created_model.validation_error is None finally: _clean_up_model(created_model) -@unittest.skipUnless(_TF_ENABLED, 'Tensor flow is required for this test.') +@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.') def test_from_saved_model(saved_model_dir): # Test the conversion helper source = ml.TFLiteGCSModelSource.from_saved_model(saved_model_dir, 'model3.tflite') @@ -370,10 +350,11 @@ def test_from_saved_model(saved_model_dir): source.gcs_tflite_uri) is not None # Validate the conversion by creating a model + model_format = ml.TFLiteFormat(model_source=source) + model = ml.Model(display_name="SavedModel1", model_format=model_format) + created_model = ml.create_model(model) + try: - model_format = ml.TFLiteFormat(model_source=source) - model = ml.Model(display_name="SavedModel1", model_format=model_format) - created_model = ml.create_model(model) assert created_model.model_id is not None assert created_model.validation_error is None finally: From ae9527f53180bed2005d211c5af70415b37bc05e Mon Sep 17 00:00:00 2001 From: ifielker Date: Thu, 30 Jan 2020 12:32:27 -0500 Subject: [PATCH 4/4] review suggestions #3 --- integration/test_ml.py | 56 +++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/integration/test_ml.py b/integration/test_ml.py index 48190f954..4c44f3d10 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -13,16 +13,17 @@ # limitations under the License. """Integration tests for firebase_admin.ml module.""" -import re import os -import shutil import random +import re +import shutil +import string import tempfile import pytest -from firebase_admin import ml from firebase_admin import exceptions +from firebase_admin import ml from tests import testutils @@ -34,27 +35,34 @@ _TF_ENABLED = False +def _random_identifier(prefix): + #pylint: disable=unused-variable + suffix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)]) + return '{0}_{1}'.format(prefix, suffix) + + NAME_ONLY_ARGS = { - 'display_name': 'TestModel123_{0}'.format(random.randint(1111, 9999)) + 'display_name': _random_identifier('TestModel123_') } NAME_ONLY_ARGS_UPDATED = { - 'display_name': 'TestModel123_updated_{0}'.format(random.randint(1111, 9999)) + 'display_name': _random_identifier('TestModel123_updated_') } NAME_AND_TAGS_ARGS = { - 'display_name': 'TestModel123_tags_{0}'.format(random.randint(1111, 9999)), + 'display_name': _random_identifier('TestModel123_tags_'), 'tags': ['test_tag123'] } FULL_MODEL_ARGS = { - 'display_name': 'TestModel123_full_{0}'.format(random.randint(1111, 9999)), + 'display_name': _random_identifier('TestModel123_full_'), 'tags': ['test_tag567'], 'file_name': 'model1.tflite' } INVALID_FULL_MODEL_ARGS = { - 'display_name': 'TestModel123_invalid_full_{0}'.format(random.randint(1111, 9999)), + 'display_name': _random_identifier('TestModel123_invalid_full_'), 'tags': ['test_tag890'], 'file_name': 'invalid_model.tflite' } + @pytest.fixture def firebase_model(request): args = request.param @@ -76,10 +84,11 @@ def firebase_model(request): @pytest.fixture def model_list(): - ml_model_1 = ml.Model(display_name="TestModel123") + ml_model_1 = ml.Model(display_name=_random_identifier('TestModel123_list1_')) model_1 = ml.create_model(model=ml_model_1) - ml_model_2 = ml.Model(display_name="TestModel123_tags", tags=['test_tag123']) + ml_model_2 = ml.Model(display_name=_random_identifier('TestModel123_list2_'), + tags=['test_tag123']) model_2 = ml.create_model(model=ml_model_2) yield [model_1, model_2] @@ -124,7 +133,7 @@ def check_model(model, args): assert model.etag is not None -def check_model_format(model, has_model_format, validation_error): +def check_model_format(model, has_model_format=False, validation_error=None): if has_model_format: assert model.validation_error == validation_error assert model.published is False @@ -145,13 +154,13 @@ def check_model_format(model, has_model_format, validation_error): @pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) def test_create_simple_model(firebase_model): check_model(firebase_model, NAME_AND_TAGS_ARGS) - check_model_format(firebase_model, False, None) + check_model_format(firebase_model) @pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) def test_create_full_model(firebase_model): check_model(firebase_model, FULL_MODEL_ARGS) - check_model_format(firebase_model, True, None) + check_model_format(firebase_model, True) @pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) @@ -173,7 +182,7 @@ def test_create_invalid_model(firebase_model): def test_get_model(firebase_model): get_model = ml.get_model(firebase_model.model_id) check_model(get_model, NAME_AND_TAGS_ARGS) - check_model_format(get_model, False, None) + check_model_format(get_model) @pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) @@ -192,12 +201,12 @@ def test_update_model(firebase_model): firebase_model.display_name = new_model_name updated_model = ml.update_model(firebase_model) check_model(updated_model, NAME_ONLY_ARGS_UPDATED) - check_model_format(updated_model, False, None) + check_model_format(updated_model) # Second call with same model does not cause error updated_model2 = ml.update_model(updated_model) check_model(updated_model2, NAME_ONLY_ARGS_UPDATED) - check_model_format(updated_model2, False, None) + check_model_format(updated_model2) @pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) @@ -304,10 +313,13 @@ def keras_model(): @pytest.fixture def saved_model_dir(keras_model): assert _TF_ENABLED - # different versions have different model conversion capability - # pick something that works for each version + # Make a new parent directory. The child directory must not exist yet. + # The child directory gets created by tf. If it exists, the tf call fails. parent = tempfile.mkdtemp() save_dir = os.path.join(parent, 'child') + + # different versions have different model conversion capability + # pick something that works for each version if tf.version.VERSION.startswith('1.'): tf.reset_default_graph() x_var = tf.placeholder(tf.float32, (None, 3), name="x") @@ -331,12 +343,12 @@ def test_from_keras_model(keras_model): # Validate the conversion by creating a model model_format = ml.TFLiteFormat(model_source=source) - model = ml.Model(display_name="KerasModel1", model_format=model_format) + model = ml.Model(display_name=_random_identifier('KerasModel_'), model_format=model_format) created_model = ml.create_model(model) try: - assert created_model.model_id is not None - assert created_model.validation_error is None + check_model(created_model, {'display_name': model.display_name}) + check_model_format(created_model, True) finally: _clean_up_model(created_model) @@ -351,7 +363,7 @@ def test_from_saved_model(saved_model_dir): # Validate the conversion by creating a model model_format = ml.TFLiteFormat(model_source=source) - model = ml.Model(display_name="SavedModel1", model_format=model_format) + model = ml.Model(display_name=_random_identifier('SavedModel_'), model_format=model_format) created_model = ml.create_model(model) try: