From 29aab680276addfce154fa042aad36ab7ac9843f Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 14 Apr 2020 18:01:28 -0400 Subject: [PATCH 1/3] add Headers --- firebase_admin/ml.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 06429b5a1..58d3e9aab 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -27,6 +27,7 @@ import requests +import firebase_admin from firebase_admin import _http_client from firebase_admin import _utils from firebase_admin import exceptions @@ -785,6 +786,9 @@ def __init__(self, app): self._project_url = _MLService.PROJECT_URL.format(self._project_id) self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), + headers={ + 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + }, base_url=self._project_url) self._operation_client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), From 325d10c56d614492240bcd328aa061d9070513d5 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 14 Apr 2020 18:20:18 -0400 Subject: [PATCH 2/3] add tests --- firebase_admin/ml.py | 3 +++ tests/test_ml.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 58d3e9aab..b65cd7d37 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -792,6 +792,9 @@ def __init__(self, app): base_url=self._project_url) self._operation_client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), + headers={ + 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + }, base_url=_MLService.OPERATION_URL) def get_operation(self, op_name): diff --git a/tests/test_ml.py b/tests/test_ml.py index 8813792e6..46f36e50a 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -25,6 +25,8 @@ BASE_URL = 'https://firebaseml.googleapis.com/v1beta2/' +HEADER_CLIENT_KEY = 'X-FIREBASE-CLIENT' +HEADER_CLIENT_VALUE = 'fire-admin-python/3.2.1' PROJECT_ID = 'my-project-1' PAGE_TOKEN = 'pageToken' @@ -536,6 +538,7 @@ def test_wait_for_unlocked(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestModel._op_url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_wait_for_unlocked_timeout(self): recorder = instrument_ml_service( @@ -589,8 +592,10 @@ def test_returns_locked(self): assert len(recorder) == 2 assert recorder[0].method == 'POST' assert recorder[0].url == TestCreateModel._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert recorder[1].method == 'GET' assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_operation_error(self): instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) @@ -681,8 +686,10 @@ def test_returns_locked(self): assert len(recorder) == 2 assert recorder[0].method == 'PATCH' assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert recorder[1].method == 'GET' assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_operation_error(self): instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) @@ -778,6 +785,7 @@ def test_immediate_done(self, publish_function, published): assert len(recorder) == 1 assert recorder[0].method == 'PATCH' assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE body = json.loads(recorder[0].body.decode()) assert body.get('state', {}).get('published', None) is published @@ -793,8 +801,10 @@ def test_returns_locked(self, publish_function): assert len(recorder) == 2 assert recorder[0].method == 'PATCH' assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert recorder[1].method == 'GET' assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_operation_error(self, publish_function): @@ -847,6 +857,7 @@ def test_get_model(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert model == MODEL_1 assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 @@ -870,6 +881,7 @@ def test_get_model_error(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_no_project_id(self): def evaluate(): @@ -900,6 +912,7 @@ def test_delete_model(self): assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) def test_delete_model_validation_errors(self, model_id, exc_type): @@ -920,6 +933,7 @@ def test_delete_model_error(self): assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_no_project_id(self): def evaluate(): @@ -957,6 +971,7 @@ def test_list_models_no_args(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestListModels._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE TestListModels._check_page(models_page, 2) assert models_page.has_next_page assert models_page.next_page_token == NEXT_PAGE_TOKEN @@ -975,6 +990,7 @@ def test_list_models_with_all_args(self): TestListModels._url(PROJECT_ID) + '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' .format(PAGE_TOKEN)) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert isinstance(models_page, ml.ListModelsPage) assert len(models_page.models) == 1 assert models_page.models[0] == MODEL_3 @@ -1020,6 +1036,7 @@ def test_list_models_error(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestListModels._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_no_project_id(self): def evaluate(): From b82a3f7eeed843389bdace14cd07a756acfb2267 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 14 Apr 2020 19:17:04 -0400 Subject: [PATCH 3/3] review comments --- firebase_admin/ml.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index b65cd7d37..db1657839 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -784,17 +784,16 @@ def __init__(self, app): 'Project ID is required to access ML service. Either set the ' 'projectId option, or use service account credentials.') self._project_url = _MLService.PROJECT_URL.format(self._project_id) + ml_headers = { + 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + } self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), - headers={ - 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), - }, + headers=ml_headers, base_url=self._project_url) self._operation_client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), - headers={ - 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), - }, + headers=ml_headers, base_url=_MLService.OPERATION_URL) def get_operation(self, op_name):