diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 06429b5a1..db1657839 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -27,6 +27,7 @@ import requests +import firebase_admin from firebase_admin import _http_client from firebase_admin import _utils from firebase_admin import exceptions @@ -783,11 +784,16 @@ def __init__(self, app): 'Project ID is required to access ML service. Either set the ' 'projectId option, or use service account credentials.') self._project_url = _MLService.PROJECT_URL.format(self._project_id) + ml_headers = { + 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + } self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), + headers=ml_headers, base_url=self._project_url) self._operation_client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), + headers=ml_headers, base_url=_MLService.OPERATION_URL) def get_operation(self, op_name): diff --git a/tests/test_ml.py b/tests/test_ml.py index 8813792e6..46f36e50a 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -25,6 +25,8 @@ BASE_URL = 'https://firebaseml.googleapis.com/v1beta2/' +HEADER_CLIENT_KEY = 'X-FIREBASE-CLIENT' +HEADER_CLIENT_VALUE = 'fire-admin-python/3.2.1' PROJECT_ID = 'my-project-1' PAGE_TOKEN = 'pageToken' @@ -536,6 +538,7 @@ def test_wait_for_unlocked(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestModel._op_url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_wait_for_unlocked_timeout(self): recorder = instrument_ml_service( @@ -589,8 +592,10 @@ def test_returns_locked(self): assert len(recorder) == 2 assert recorder[0].method == 'POST' assert recorder[0].url == TestCreateModel._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert recorder[1].method == 'GET' assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_operation_error(self): instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) @@ -681,8 +686,10 @@ def test_returns_locked(self): assert len(recorder) == 2 assert recorder[0].method == 'PATCH' assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert recorder[1].method == 'GET' assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_operation_error(self): instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE) @@ -778,6 +785,7 @@ def test_immediate_done(self, publish_function, published): assert len(recorder) == 1 assert recorder[0].method == 'PATCH' assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE body = json.loads(recorder[0].body.decode()) assert body.get('state', {}).get('published', None) is published @@ -793,8 +801,10 @@ def test_returns_locked(self, publish_function): assert len(recorder) == 2 assert recorder[0].method == 'PATCH' assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert recorder[1].method == 'GET' assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_operation_error(self, publish_function): @@ -847,6 +857,7 @@ def test_get_model(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert model == MODEL_1 assert model.model_id == MODEL_ID_1 assert model.display_name == DISPLAY_NAME_1 @@ -870,6 +881,7 @@ def test_get_model_error(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_no_project_id(self): def evaluate(): @@ -900,6 +912,7 @@ def test_delete_model(self): assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE @pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS) def test_delete_model_validation_errors(self, model_id, exc_type): @@ -920,6 +933,7 @@ def test_delete_model_error(self): assert len(recorder) == 1 assert recorder[0].method == 'DELETE' assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_no_project_id(self): def evaluate(): @@ -957,6 +971,7 @@ def test_list_models_no_args(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestListModels._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE TestListModels._check_page(models_page, 2) assert models_page.has_next_page assert models_page.next_page_token == NEXT_PAGE_TOKEN @@ -975,6 +990,7 @@ def test_list_models_with_all_args(self): TestListModels._url(PROJECT_ID) + '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' .format(PAGE_TOKEN)) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE assert isinstance(models_page, ml.ListModelsPage) assert len(models_page.models) == 1 assert models_page.models[0] == MODEL_3 @@ -1020,6 +1036,7 @@ def test_list_models_error(self): assert len(recorder) == 1 assert recorder[0].method == 'GET' assert recorder[0].url == TestListModels._url(PROJECT_ID) + assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE def test_no_project_id(self): def evaluate():