Skip to content

Firebase ML Kit Create Model API implementation #337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Sep 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions firebase_admin/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,27 @@ def handle_platform_error_from_requests(error, handle_func=None):
return exc if exc else _handle_func_requests(error, message, error_dict)


def handle_operation_error(error):
"""Constructs a ``FirebaseError`` from the given operation error.

Args:
error: An error returned by a long running operation.

Returns:
FirebaseError: A ``FirebaseError`` that can be raised to the user code.
"""
if not isinstance(error, dict):
return exceptions.UnknownError(
message='Unknown error while making a remote service call: {0}'.format(error),
cause=error)

status_code = error.get('code')
message = error.get('message')
error_code = _http_status_to_error_code(status_code)
err_type = _error_code_to_exception_type(error_code)
return err_type(message=message)


def _handle_func_requests(error, message, error_dict):
"""Constructs a ``FirebaseError`` from the given GCP error.

Expand Down
186 changes: 180 additions & 6 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
import datetime
import numbers
import re
import time
import requests
import six


from firebase_admin import _http_client
from firebase_admin import _utils
from firebase_admin import exceptions


_MLKIT_ATTRIBUTE = '_mlkit'
Expand All @@ -36,6 +39,9 @@
_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+')
_RESOURCE_NAME_PATTERN = re.compile(
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
_OPERATION_NAME_PATTERN = re.compile(
r'^operations/project/(?P<project_id>[^/]+)/model/(?P<model_id>[A-Za-z0-9_-]{1,60})' +
r'/operation/[^/]+$')


def _get_mlkit_service(app):
Expand All @@ -53,18 +59,60 @@ def _get_mlkit_service(app):
return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService)


def create_model(model, app=None):
"""Creates a model in Firebase ML Kit.

Args:
model: An mlkit.Model to create.
app: A Firebase app instance (or None to use the default app).

Returns:
Model: The model that was created in Firebase ML Kit.
"""
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.create_model(model), app=app)


def get_model(model_id, app=None):
"""Gets a model from Firebase ML Kit.

Args:
model_id: The id of the model to get.
app: A Firebase app instance (or None to use the default app).

Returns:
Model: The requested model.
"""
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.get_model(model_id))
return Model.from_dict(mlkit_service.get_model(model_id), app=app)


def list_models(list_filter=None, page_size=None, page_token=None, app=None):
"""Lists models from Firebase ML Kit.

Args:
list_filter: a list filter string such as "tags:'tag_1'". None will return all models.
page_size: A number between 1 and 100 inclusive that specifies the maximum
number of models to return per page. None for default.
page_token: A next page token returned from a previous page of results. None
for first page of results.
app: A Firebase app instance (or None to use the default app).

Returns:
ListModelsPage: A (filtered) list of models.
"""
mlkit_service = _get_mlkit_service(app)
return ListModelsPage(
mlkit_service.list_models, list_filter, page_size, page_token)
mlkit_service.list_models, list_filter, page_size, page_token, app=app)


def delete_model(model_id, app=None):
"""Deletes a model from Firebase ML Kit.

Args:
model_id: The id of the model you wish to delete.
app: A Firebase app instance (or None to use the default app).
"""
mlkit_service = _get_mlkit_service(app)
mlkit_service.delete_model(model_id)

Expand All @@ -78,6 +126,7 @@ class Model(object):
model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details.
"""
def __init__(self, display_name=None, tags=None, model_format=None):
self._app = None # Only needed for wait_for_unlo
self._data = {}
self._model_format = None

Expand All @@ -89,16 +138,22 @@ def __init__(self, display_name=None, tags=None, model_format=None):
self.model_format = model_format

@classmethod
def from_dict(cls, data):
def from_dict(cls, data, app=None):
data_copy = dict(data)
tflite_format = None
tflite_format_data = data_copy.pop('tfliteModel', None)
if tflite_format_data:
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
model = Model(model_format=tflite_format)
model._data = data_copy # pylint: disable=protected-access
model._app = app # pylint: disable=protected-access
return model

def _update_from_dict(self, data):
copy = Model.from_dict(data)
self.model_format = copy.model_format
self._data = copy._data # pylint: disable=protected-access

def __eq__(self, other):
if isinstance(other, self.__class__):
# pylint: disable=protected-access
Expand Down Expand Up @@ -173,6 +228,26 @@ def locked(self):
return bool(self._data.get('activeOperations') and
len(self._data.get('activeOperations')) > 0)

def wait_for_unlocked(self, max_time_seconds=None):
"""Waits for the model to be unlocked. (All active operations complete)

Args:
max_time_seconds: The maximum number of seconds to wait for the model to unlock.
(None for no limit)

Raises:
exceptions.DeadlineExceeded: If max_time_seconds passed and the model is still locked.
"""
if not self.locked:
return
mlkit_service = _get_mlkit_service(self._app)
op_name = self._data.get('activeOperations')[0].get('name')
model_dict = mlkit_service.handle_operation(
mlkit_service.get_operation(op_name),
wait_for_operation=True,
max_time_seconds=max_time_seconds)
self._update_from_dict(model_dict)

@property
def model_format(self):
return self._model_format
Expand Down Expand Up @@ -296,17 +371,20 @@ class ListModelsPage(object):
``iterate_all()`` can be used to iterate through all the models in the
Firebase project starting from this page.
"""
def __init__(self, list_models_func, list_filter, page_size, page_token):
def __init__(self, list_models_func, list_filter, page_size, page_token, app):
self._list_models_func = list_models_func
self._list_filter = list_filter
self._page_size = page_size
self._page_token = page_token
self._app = app
self._list_response = list_models_func(list_filter, page_size, page_token)

@property
def models(self):
"""A list of Models from this page."""
return [Model.from_dict(model) for model in self._list_response.get('models', [])]
return [
Model.from_dict(model, app=self._app) for model in self._list_response.get('models', [])
]

@property
def list_filter(self):
Expand All @@ -333,7 +411,8 @@ def get_next_page(self):
self._list_models_func,
self._list_filter,
self._page_size,
self.next_page_token)
self.next_page_token,
self._app)
return None

def iterate_all(self):
Expand Down Expand Up @@ -390,11 +469,25 @@ def _validate_and_parse_name(name):
return matcher.group('project_id'), matcher.group('model_id')


def _validate_model(model):
if not isinstance(model, Model):
raise TypeError('Model must be an mlkit.Model.')
if not model.display_name:
raise ValueError('Model must have a display name.')


def _validate_model_id(model_id):
if not _MODEL_ID_PATTERN.match(model_id):
raise ValueError('Model ID format is invalid.')


def _validate_and_parse_operation_name(op_name):
matcher = _OPERATION_NAME_PATTERN.match(op_name)
if not matcher:
raise ValueError('Operation name format is invalid.')
return matcher.group('project_id'), matcher.group('model_id')


def _validate_display_name(display_name):
if not _DISPLAY_NAME_PATTERN.match(display_name):
raise ValueError('Display name format is invalid.')
Expand All @@ -417,11 +510,13 @@ def _validate_gcs_tflite_uri(uri):
raise ValueError('GCS TFLite URI format is invalid.')
return uri


def _validate_model_format(model_format):
if not isinstance(model_format, ModelFormat):
raise TypeError('Model format must be a ModelFormat object.')
return model_format


def _validate_list_filter(list_filter):
if list_filter is not None:
if not isinstance(list_filter, six.string_types):
Expand All @@ -448,6 +543,9 @@ class _MLKitService(object):
"""Firebase MLKit service."""

PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/'
OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/'
POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5
POLL_BASE_WAIT_TIME_SECONDS = 3

def __init__(self, app):
project_id = app.project_id
Expand All @@ -459,6 +557,82 @@ def __init__(self, app):
self._client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
base_url=self._project_url)
self._operation_client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
base_url=_MLKitService.OPERATION_URL)

def get_operation(self, op_name):
_validate_and_parse_operation_name(op_name)
try:
return self._operation_client.body('get', url=op_name)
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def _exponential_backoff(self, current_attempt, stop_time):
"""Sleeps for the appropriate amount of time. Or throws deadline exceeded."""
delay_factor = pow(_MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt)
wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS

if stop_time is not None:
max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds()
if max_seconds_left < 1: # allow a bit of time for rpc
raise exceptions.DeadlineExceededError('Polling max time exceeded.')
else:
wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1)
time.sleep(wait_time_seconds)


def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None):
"""Handles long running operations.

Args:
operation: The operation to handle.
wait_for_operation: Should we allow polling for the operation to complete.
If no polling is requested, a locked model will be returned instead.
max_time_seconds: The maximum seconds to try polling for operation complete.
(None for no limit)

Returns:
dict: A dictionary of the returned model properties.

Raises:
TypeError: if the operation is not a dictionary.
ValueError: If the operation is malformed.
err: If the operation exceeds polling attempts or stop_time
"""
if not isinstance(operation, dict):
raise TypeError('Operation must be a dictionary.')
op_name = operation.get('name')
_, model_id = _validate_and_parse_operation_name(op_name)

current_attempt = 0
start_time = datetime.datetime.now()
stop_time = (None if max_time_seconds is None else
start_time + datetime.timedelta(seconds=max_time_seconds))
while wait_for_operation and not operation.get('done'):
# We just got this operation. Wait before getting another
# so we don't exceed the GetOperation maximum request rate.
self._exponential_backoff(current_attempt, stop_time)
operation = self.get_operation(op_name)
current_attempt += 1

if operation.get('done'):
if operation.get('response'):
return operation.get('response')
elif operation.get('error'):
raise _utils.handle_operation_error(operation.get('error'))

# If the operation is not complete or timed out, return a (locked) model instead
return get_model(model_id).as_dict()


def create_model(self, model):
_validate_model(model)
try:
return self.handle_operation(
self._client.body('post', url='models', json=model.as_dict()))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def get_model(self, model_id):
_validate_model_id(model_id)
Expand Down
Loading