Skip to content

Commit a247f13

Browse files
authored
Firebase ML Kit List Models API implementation (#331)
* implemented list models plus tests
1 parent a84d3f6 commit a247f13

File tree

2 files changed

+394
-20
lines changed

2 files changed

+394
-20
lines changed

firebase_admin/mlkit.py

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828

2929
_MLKIT_ATTRIBUTE = '_mlkit'
30+
_MAX_PAGE_SIZE = 100
3031

3132

3233
def _get_mlkit_service(app):
@@ -49,6 +50,12 @@ def get_model(model_id, app=None):
4950
return Model(mlkit_service.get_model(model_id))
5051

5152

53+
def list_models(list_filter=None, page_size=None, page_token=None, app=None):
54+
mlkit_service = _get_mlkit_service(app)
55+
return ListModelsPage(
56+
mlkit_service.list_models, list_filter, page_size, page_token)
57+
58+
5259
def delete_model(model_id, app=None):
5360
mlkit_service = _get_mlkit_service(app)
5461
mlkit_service.delete_model(model_id)
@@ -69,7 +76,107 @@ def __eq__(self, other):
6976
def __ne__(self, other):
7077
return not self.__eq__(other)
7178

72-
#TODO(ifielker): define the Model properties etc
79+
@property
80+
def name(self):
81+
return self._data['name']
82+
83+
@property
84+
def display_name(self):
85+
return self._data['displayName']
86+
87+
#TODO(ifielker): define the rest of the Model properties etc
88+
89+
90+
class ListModelsPage(object):
91+
"""Represents a page of models in a firebase project.
92+
93+
Provides methods for traversing the models included in this page, as well as
94+
retrieving subsequent pages of models. The iterator returned by
95+
``iterate_all()`` can be used to iterate through all the models in the
96+
Firebase project starting from this page.
97+
"""
98+
def __init__(self, list_models_func, list_filter, page_size, page_token):
99+
self._list_models_func = list_models_func
100+
self._list_filter = list_filter
101+
self._page_size = page_size
102+
self._page_token = page_token
103+
self._list_response = list_models_func(list_filter, page_size, page_token)
104+
105+
@property
106+
def models(self):
107+
"""A list of Models from this page."""
108+
return [Model(model) for model in self._list_response.get('models', [])]
109+
110+
@property
111+
def list_filter(self):
112+
"""The filter string used to filter the models."""
113+
return self._list_filter
114+
115+
@property
116+
def next_page_token(self):
117+
return self._list_response.get('nextPageToken', '')
118+
119+
@property
120+
def has_next_page(self):
121+
"""A boolean indicating whether more pages are available."""
122+
return bool(self.next_page_token)
123+
124+
def get_next_page(self):
125+
"""Retrieves the next page of models if available.
126+
127+
Returns:
128+
ListModelsPage: Next page of models, or None if this is the last page.
129+
"""
130+
if self.has_next_page:
131+
return ListModelsPage(
132+
self._list_models_func,
133+
self._list_filter,
134+
self._page_size,
135+
self.next_page_token)
136+
return None
137+
138+
def iterate_all(self):
139+
"""Retrieves an iterator for Models.
140+
141+
Returned iterator will iterate through all the models in the Firebase
142+
project starting from this page. The iterator will never buffer more than
143+
one page of models in memory at a time.
144+
145+
Returns:
146+
iterator: An iterator of Model instances.
147+
"""
148+
return _ModelIterator(self)
149+
150+
151+
class _ModelIterator(object):
152+
"""An iterator that allows iterating over models, one at a time.
153+
154+
This implementation loads a page of models into memory, and iterates on them.
155+
When the whole page has been traversed, it loads another page. This class
156+
never keeps more than one page of entries in memory.
157+
"""
158+
def __init__(self, current_page):
159+
if not isinstance(current_page, ListModelsPage):
160+
raise TypeError('Current page must be a ListModelsPage')
161+
self._current_page = current_page
162+
self._index = 0
163+
164+
def next(self):
165+
if self._index == len(self._current_page.models):
166+
if self._current_page.has_next_page:
167+
self._current_page = self._current_page.get_next_page()
168+
self._index = 0
169+
if self._index < len(self._current_page.models):
170+
result = self._current_page.models[self._index]
171+
self._index += 1
172+
return result
173+
raise StopIteration
174+
175+
def __next__(self):
176+
return self.next()
177+
178+
def __iter__(self):
179+
return self
73180

74181

75182
def _validate_model_id(model_id):
@@ -79,6 +186,28 @@ def _validate_model_id(model_id):
79186
raise ValueError('Model ID format is invalid.')
80187

81188

189+
def _validate_list_filter(list_filter):
190+
if list_filter is not None:
191+
if not isinstance(list_filter, six.string_types):
192+
raise TypeError('List filter must be a string or None.')
193+
194+
195+
def _validate_page_size(page_size):
196+
if page_size is not None:
197+
if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck
198+
# Specifically type() to disallow boolean which is a subtype of int
199+
raise TypeError('Page size must be a number or None.')
200+
if page_size < 1 or page_size > _MAX_PAGE_SIZE:
201+
raise ValueError('Page size must be a positive integer between '
202+
'1 and {0}'.format(_MAX_PAGE_SIZE))
203+
204+
205+
def _validate_page_token(page_token):
206+
if page_token is not None:
207+
if not isinstance(page_token, six.string_types):
208+
raise TypeError('Page token must be a string or None.')
209+
210+
82211
class _MLKitService(object):
83212
"""Firebase MLKit service."""
84213

@@ -102,6 +231,23 @@ def get_model(self, model_id):
102231
except requests.exceptions.RequestException as error:
103232
raise _utils.handle_platform_error_from_requests(error)
104233

234+
def list_models(self, list_filter, page_size, page_token):
235+
""" lists Firebase ML Kit models."""
236+
_validate_list_filter(list_filter)
237+
_validate_page_size(page_size)
238+
_validate_page_token(page_token)
239+
payload = {}
240+
if list_filter:
241+
payload['list_filter'] = list_filter
242+
if page_size:
243+
payload['page_size'] = page_size
244+
if page_token:
245+
payload['page_token'] = page_token
246+
try:
247+
return self._client.body('get', url='models', json=payload)
248+
except requests.exceptions.RequestException as error:
249+
raise _utils.handle_platform_error_from_requests(error)
250+
105251
def delete_model(self, model_id):
106252
_validate_model_id(model_id)
107253
try:

0 commit comments

Comments
 (0)