Skip to content

Commit f90a021

Browse files
authored
Mlkit add headers (#445)
* add Headers
1 parent e49add8 commit f90a021

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

firebase_admin/ml.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import requests
2929

30+
import firebase_admin
3031
from firebase_admin import _http_client
3132
from firebase_admin import _utils
3233
from firebase_admin import exceptions
@@ -783,11 +784,16 @@ def __init__(self, app):
783784
'Project ID is required to access ML service. Either set the '
784785
'projectId option, or use service account credentials.')
785786
self._project_url = _MLService.PROJECT_URL.format(self._project_id)
787+
ml_headers = {
788+
'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__),
789+
}
786790
self._client = _http_client.JsonHttpClient(
787791
credential=app.credential.get_credential(),
792+
headers=ml_headers,
788793
base_url=self._project_url)
789794
self._operation_client = _http_client.JsonHttpClient(
790795
credential=app.credential.get_credential(),
796+
headers=ml_headers,
791797
base_url=_MLService.OPERATION_URL)
792798

793799
def get_operation(self, op_name):

tests/test_ml.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626

2727
BASE_URL = 'https://firebaseml.googleapis.com/v1beta2/'
28+
HEADER_CLIENT_KEY = 'X-FIREBASE-CLIENT'
29+
HEADER_CLIENT_VALUE = 'fire-admin-python/3.2.1'
2830
PROJECT_ID = 'my-project-1'
2931

3032
PAGE_TOKEN = 'pageToken'
@@ -536,6 +538,7 @@ def test_wait_for_unlocked(self):
536538
assert len(recorder) == 1
537539
assert recorder[0].method == 'GET'
538540
assert recorder[0].url == TestModel._op_url(PROJECT_ID)
541+
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
539542

540543
def test_wait_for_unlocked_timeout(self):
541544
recorder = instrument_ml_service(
@@ -589,8 +592,10 @@ def test_returns_locked(self):
589592
assert len(recorder) == 2
590593
assert recorder[0].method == 'POST'
591594
assert recorder[0].url == TestCreateModel._url(PROJECT_ID)
595+
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
592596
assert recorder[1].method == 'GET'
593597
assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1)
598+
assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
594599

595600
def test_operation_error(self):
596601
instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE)
@@ -681,8 +686,10 @@ def test_returns_locked(self):
681686
assert len(recorder) == 2
682687
assert recorder[0].method == 'PATCH'
683688
assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
689+
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
684690
assert recorder[1].method == 'GET'
685691
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
692+
assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
686693

687694
def test_operation_error(self):
688695
instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE)
@@ -778,6 +785,7 @@ def test_immediate_done(self, publish_function, published):
778785
assert len(recorder) == 1
779786
assert recorder[0].method == 'PATCH'
780787
assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)
788+
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
781789
body = json.loads(recorder[0].body.decode())
782790
assert body.get('state', {}).get('published', None) is published
783791

@@ -793,8 +801,10 @@ def test_returns_locked(self, publish_function):
793801
assert len(recorder) == 2
794802
assert recorder[0].method == 'PATCH'
795803
assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)
804+
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
796805
assert recorder[1].method == 'GET'
797806
assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1)
807+
assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
798808

799809
@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
800810
def test_operation_error(self, publish_function):
@@ -847,6 +857,7 @@ def test_get_model(self):
847857
assert len(recorder) == 1
848858
assert recorder[0].method == 'GET'
849859
assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1)
860+
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
850861
assert model == MODEL_1
851862
assert model.model_id == MODEL_ID_1
852863
assert model.display_name == DISPLAY_NAME_1
@@ -870,6 +881,7 @@ def test_get_model_error(self):
870881
assert len(recorder) == 1
871882
assert recorder[0].method == 'GET'
872883
assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1)
884+
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
873885

874886
def test_no_project_id(self):
875887
def evaluate():
@@ -900,6 +912,7 @@ def test_delete_model(self):
900912
assert len(recorder) == 1
901913
assert recorder[0].method == 'DELETE'
902914
assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1)
915+
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
903916

904917
@pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS)
905918
def test_delete_model_validation_errors(self, model_id, exc_type):
@@ -920,6 +933,7 @@ def test_delete_model_error(self):
920933
assert len(recorder) == 1
921934
assert recorder[0].method == 'DELETE'
922935
assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1)
936+
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
923937

924938
def test_no_project_id(self):
925939
def evaluate():
@@ -957,6 +971,7 @@ def test_list_models_no_args(self):
957971
assert len(recorder) == 1
958972
assert recorder[0].method == 'GET'
959973
assert recorder[0].url == TestListModels._url(PROJECT_ID)
974+
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
960975
TestListModels._check_page(models_page, 2)
961976
assert models_page.has_next_page
962977
assert models_page.next_page_token == NEXT_PAGE_TOKEN
@@ -975,6 +990,7 @@ def test_list_models_with_all_args(self):
975990
TestListModels._url(PROJECT_ID) +
976991
'?filter=display_name%3DdisplayName3&page_size=10&page_token={0}'
977992
.format(PAGE_TOKEN))
993+
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
978994
assert isinstance(models_page, ml.ListModelsPage)
979995
assert len(models_page.models) == 1
980996
assert models_page.models[0] == MODEL_3
@@ -1020,6 +1036,7 @@ def test_list_models_error(self):
10201036
assert len(recorder) == 1
10211037
assert recorder[0].method == 'GET'
10221038
assert recorder[0].url == TestListModels._url(PROJECT_ID)
1039+
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
10231040

10241041
def test_no_project_id(self):
10251042
def evaluate():

0 commit comments

Comments
 (0)