Skip to content

Mlkit add headers #445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions firebase_admin/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import requests

import firebase_admin
from firebase_admin import _http_client
from firebase_admin import _utils
from firebase_admin import exceptions
Expand Down Expand Up @@ -783,11 +784,16 @@ def __init__(self, app):
'Project ID is required to access ML service. Either set the '
'projectId option, or use service account credentials.')
self._project_url = _MLService.PROJECT_URL.format(self._project_id)
ml_headers = {
'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__),
}
self._client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
headers=ml_headers,
base_url=self._project_url)
self._operation_client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
headers=ml_headers,
base_url=_MLService.OPERATION_URL)

def get_operation(self, op_name):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@


BASE_URL = 'https://firebaseml.googleapis.com/v1beta2/'
HEADER_CLIENT_KEY = 'X-FIREBASE-CLIENT'
HEADER_CLIENT_VALUE = 'fire-admin-python/3.2.1'
PROJECT_ID = 'my-project-1'

PAGE_TOKEN = 'pageToken'
Expand Down Expand Up @@ -536,6 +538,7 @@ def test_wait_for_unlocked(self):
assert len(recorder) == 1
assert recorder[0].method == 'GET'
assert recorder[0].url == TestModel._op_url(PROJECT_ID)
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE

def test_wait_for_unlocked_timeout(self):
recorder = instrument_ml_service(
Expand Down Expand Up @@ -589,8 +592,10 @@ def test_returns_locked(self):
assert len(recorder) == 2
assert recorder[0].method == 'POST'
assert recorder[0].url == TestCreateModel._url(PROJECT_ID)
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
assert recorder[1].method == 'GET'
assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE

def test_operation_error(self):
instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE)
Expand Down Expand Up @@ -681,8 +686,10 @@ def test_returns_locked(self):
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
assert recorder[1].method == 'GET'
assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE

def test_operation_error(self):
instrument_ml_service(status=200, payload=OPERATION_ERROR_RESPONSE)
Expand Down Expand Up @@ -778,6 +785,7 @@ def test_immediate_done(self, publish_function, published):
assert len(recorder) == 1
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
body = json.loads(recorder[0].body.decode())
assert body.get('state', {}).get('published', None) is published

Expand All @@ -793,8 +801,10 @@ def test_returns_locked(self, publish_function):
assert len(recorder) == 2
assert recorder[0].method == 'PATCH'
assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1)
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
assert recorder[1].method == 'GET'
assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1)
assert recorder[1].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE

@pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS)
def test_operation_error(self, publish_function):
Expand Down Expand Up @@ -847,6 +857,7 @@ def test_get_model(self):
assert len(recorder) == 1
assert recorder[0].method == 'GET'
assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1)
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
assert model == MODEL_1
assert model.model_id == MODEL_ID_1
assert model.display_name == DISPLAY_NAME_1
Expand All @@ -870,6 +881,7 @@ def test_get_model_error(self):
assert len(recorder) == 1
assert recorder[0].method == 'GET'
assert recorder[0].url == TestGetModel._url(PROJECT_ID, MODEL_ID_1)
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE

def test_no_project_id(self):
def evaluate():
Expand Down Expand Up @@ -900,6 +912,7 @@ def test_delete_model(self):
assert len(recorder) == 1
assert recorder[0].method == 'DELETE'
assert recorder[0].url == TestDeleteModel._url(PROJECT_ID, MODEL_ID_1)
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE

@pytest.mark.parametrize('model_id, exc_type', INVALID_MODEL_ID_ARGS)
def test_delete_model_validation_errors(self, model_id, exc_type):
Expand All @@ -920,6 +933,7 @@ def test_delete_model_error(self):
assert len(recorder) == 1
assert recorder[0].method == 'DELETE'
assert recorder[0].url == self._url(PROJECT_ID, MODEL_ID_1)
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE

def test_no_project_id(self):
def evaluate():
Expand Down Expand Up @@ -957,6 +971,7 @@ def test_list_models_no_args(self):
assert len(recorder) == 1
assert recorder[0].method == 'GET'
assert recorder[0].url == TestListModels._url(PROJECT_ID)
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
TestListModels._check_page(models_page, 2)
assert models_page.has_next_page
assert models_page.next_page_token == NEXT_PAGE_TOKEN
Expand All @@ -975,6 +990,7 @@ def test_list_models_with_all_args(self):
TestListModels._url(PROJECT_ID) +
'?filter=display_name%3DdisplayName3&page_size=10&page_token={0}'
.format(PAGE_TOKEN))
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE
assert isinstance(models_page, ml.ListModelsPage)
assert len(models_page.models) == 1
assert models_page.models[0] == MODEL_3
Expand Down Expand Up @@ -1020,6 +1036,7 @@ def test_list_models_error(self):
assert len(recorder) == 1
assert recorder[0].method == 'GET'
assert recorder[0].url == TestListModels._url(PROJECT_ID)
assert recorder[0].headers[HEADER_CLIENT_KEY] == HEADER_CLIENT_VALUE

def test_no_project_id(self):
def evaluate():
Expand Down