Skip to content

Commit 4618b1e

Browse files
authored
Implementation of Model, ModelFormat, TFLiteModelSource and subclasses (#335)
* Implementation of Model, ModelFormat, ModelSource and subclasses
1 parent a247f13 commit 4618b1e

File tree

2 files changed

+505
-41
lines changed

2 files changed

+505
-41
lines changed

firebase_admin/mlkit.py

Lines changed: 250 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
deleting, publishing and unpublishing Firebase ML Kit models.
1919
"""
2020

21+
import datetime
22+
import numbers
2123
import re
2224
import requests
2325
import six
@@ -28,6 +30,12 @@
2830

2931
_MLKIT_ATTRIBUTE = '_mlkit'
3032
_MAX_PAGE_SIZE = 100
33+
_MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
34+
_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
35+
_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
36+
_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+')
37+
_RESOURCE_NAME_PATTERN = re.compile(
38+
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
3139

3240

3341
def _get_mlkit_service(app):
@@ -47,7 +55,7 @@ def _get_mlkit_service(app):
4755

4856
def get_model(model_id, app=None):
4957
mlkit_service = _get_mlkit_service(app)
50-
return Model(mlkit_service.get_model(model_id))
58+
return Model.from_dict(mlkit_service.get_model(model_id))
5159

5260

5361
def list_models(list_filter=None, page_size=None, page_token=None, app=None):
@@ -62,29 +70,222 @@ def delete_model(model_id, app=None):
6270

6371

6472
class Model(object):
65-
"""A Firebase ML Kit Model object."""
66-
def __init__(self, data):
67-
"""Created from a data dictionary."""
68-
self._data = data
73+
"""A Firebase ML Kit Model object.
74+
75+
Args:
76+
display_name: The display name of your model - used to identify your model in code.
77+
tags: Optional list of strings associated with your model. Can be used in list queries.
78+
model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details.
79+
"""
80+
def __init__(self, display_name=None, tags=None, model_format=None):
81+
self._data = {}
82+
self._model_format = None
83+
84+
if display_name is not None:
85+
self.display_name = display_name
86+
if tags is not None:
87+
self.tags = tags
88+
if model_format is not None:
89+
self.model_format = model_format
90+
91+
@classmethod
92+
def from_dict(cls, data):
93+
data_copy = dict(data)
94+
tflite_format = None
95+
tflite_format_data = data_copy.pop('tfliteModel', None)
96+
if tflite_format_data:
97+
tflite_format = TFLiteFormat.from_dict(tflite_format_data)
98+
model = Model(model_format=tflite_format)
99+
model._data = data_copy # pylint: disable=protected-access
100+
return model
69101

70102
def __eq__(self, other):
71103
if isinstance(other, self.__class__):
72-
return self._data == other._data # pylint: disable=protected-access
104+
# pylint: disable=protected-access
105+
return self._data == other._data and self._model_format == other._model_format
73106
else:
74107
return False
75108

76109
def __ne__(self, other):
77110
return not self.__eq__(other)
78111

79112
@property
80-
def name(self):
81-
return self._data['name']
113+
def model_id(self):
114+
if not self._data.get('name'):
115+
return None
116+
_, model_id = _validate_and_parse_name(self._data.get('name'))
117+
return model_id
82118

83119
@property
84120
def display_name(self):
85-
return self._data['displayName']
121+
return self._data.get('displayName')
122+
123+
@display_name.setter
124+
def display_name(self, display_name):
125+
self._data['displayName'] = _validate_display_name(display_name)
126+
return self
127+
128+
@property
129+
def create_time(self):
130+
"""Returns the creation timestamp"""
131+
seconds = self._data.get('createTime', {}).get('seconds')
132+
if not isinstance(seconds, numbers.Number):
133+
return None
134+
135+
return datetime.datetime.fromtimestamp(float(seconds))
136+
137+
@property
138+
def update_time(self):
139+
"""Returns the last update timestamp"""
140+
seconds = self._data.get('updateTime', {}).get('seconds')
141+
if not isinstance(seconds, numbers.Number):
142+
return None
86143

87-
#TODO(ifielker): define the rest of the Model properties etc
144+
return datetime.datetime.fromtimestamp(float(seconds))
145+
146+
@property
147+
def validation_error(self):
148+
return self._data.get('state', {}).get('validationError', {}).get('message')
149+
150+
@property
151+
def published(self):
152+
return bool(self._data.get('state', {}).get('published'))
153+
154+
@property
155+
def etag(self):
156+
return self._data.get('etag')
157+
158+
@property
159+
def model_hash(self):
160+
return self._data.get('modelHash')
161+
162+
@property
163+
def tags(self):
164+
return self._data.get('tags')
165+
166+
@tags.setter
167+
def tags(self, tags):
168+
self._data['tags'] = _validate_tags(tags)
169+
return self
170+
171+
@property
172+
def locked(self):
173+
return bool(self._data.get('activeOperations') and
174+
len(self._data.get('activeOperations')) > 0)
175+
176+
@property
177+
def model_format(self):
178+
return self._model_format
179+
180+
@model_format.setter
181+
def model_format(self, model_format):
182+
if model_format is not None:
183+
_validate_model_format(model_format)
184+
self._model_format = model_format #Can be None
185+
return self
186+
187+
def as_dict(self):
188+
copy = dict(self._data)
189+
if self._model_format:
190+
copy.update(self._model_format.as_dict())
191+
return copy
192+
193+
194+
class ModelFormat(object):
195+
"""Abstract base class representing a Model Format such as TFLite."""
196+
def as_dict(self):
197+
raise NotImplementedError
198+
199+
200+
class TFLiteFormat(ModelFormat):
201+
"""Model format representing a TFLite model.
202+
203+
Args:
204+
model_source: A TFLiteModelSource sub class. Specifies the details of the model source.
205+
"""
206+
def __init__(self, model_source=None):
207+
self._data = {}
208+
self._model_source = None
209+
210+
if model_source is not None:
211+
self.model_source = model_source
212+
213+
@classmethod
214+
def from_dict(cls, data):
215+
data_copy = dict(data)
216+
model_source = None
217+
gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None)
218+
if gcs_tflite_uri:
219+
model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri)
220+
tflite_format = TFLiteFormat(model_source=model_source)
221+
tflite_format._data = data_copy # pylint: disable=protected-access
222+
return tflite_format
223+
224+
225+
def __eq__(self, other):
226+
if isinstance(other, self.__class__):
227+
# pylint: disable=protected-access
228+
return self._data == other._data and self._model_source == other._model_source
229+
else:
230+
return False
231+
232+
def __ne__(self, other):
233+
return not self.__eq__(other)
234+
235+
@property
236+
def model_source(self):
237+
return self._model_source
238+
239+
@model_source.setter
240+
def model_source(self, model_source):
241+
if model_source is not None:
242+
if not isinstance(model_source, TFLiteModelSource):
243+
raise TypeError('Model source must be a TFLiteModelSource object.')
244+
self._model_source = model_source # Can be None
245+
246+
@property
247+
def size_bytes(self):
248+
return self._data.get('sizeBytes')
249+
250+
def as_dict(self):
251+
copy = dict(self._data)
252+
if self._model_source:
253+
copy.update(self._model_source.as_dict())
254+
return {'tfliteModel': copy}
255+
256+
257+
class TFLiteModelSource(object):
258+
"""Abstract base class representing a model source for TFLite format models."""
259+
def as_dict(self):
260+
raise NotImplementedError
261+
262+
263+
class TFLiteGCSModelSource(TFLiteModelSource):
264+
"""TFLite model source representing a tflite model file stored in GCS."""
265+
def __init__(self, gcs_tflite_uri):
266+
self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)
267+
268+
def __eq__(self, other):
269+
if isinstance(other, self.__class__):
270+
return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access
271+
else:
272+
return False
273+
274+
def __ne__(self, other):
275+
return not self.__eq__(other)
276+
277+
@property
278+
def gcs_tflite_uri(self):
279+
return self._gcs_tflite_uri
280+
281+
@gcs_tflite_uri.setter
282+
def gcs_tflite_uri(self, gcs_tflite_uri):
283+
self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)
284+
285+
def as_dict(self):
286+
return {"gcsTfliteUri": self._gcs_tflite_uri}
287+
288+
#TODO(ifielker): implement from_saved_model etc.
88289

89290

90291
class ListModelsPage(object):
@@ -105,7 +306,7 @@ def __init__(self, list_models_func, list_filter, page_size, page_token):
105306
@property
106307
def models(self):
107308
"""A list of Models from this page."""
108-
return [Model(model) for model in self._list_response.get('models', [])]
309+
return [Model.from_dict(model) for model in self._list_response.get('models', [])]
109310

110311
@property
111312
def list_filter(self):
@@ -179,13 +380,48 @@ def __iter__(self):
179380
return self
180381

181382

383+
def _validate_and_parse_name(name):
384+
# The resource name is added automatically from API call responses.
385+
# The only way it could be invalid is if someone tries to
386+
# create a model from a dictionary manually and does it incorrectly.
387+
matcher = _RESOURCE_NAME_PATTERN.match(name)
388+
if not matcher:
389+
raise ValueError('Model resource name format is invalid.')
390+
return matcher.group('project_id'), matcher.group('model_id')
391+
392+
182393
def _validate_model_id(model_id):
183-
if not isinstance(model_id, six.string_types):
184-
raise TypeError('Model ID must be a string.')
185-
if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id):
394+
if not _MODEL_ID_PATTERN.match(model_id):
186395
raise ValueError('Model ID format is invalid.')
187396

188397

398+
def _validate_display_name(display_name):
399+
if not _DISPLAY_NAME_PATTERN.match(display_name):
400+
raise ValueError('Display name format is invalid.')
401+
return display_name
402+
403+
404+
def _validate_tags(tags):
405+
if not isinstance(tags, list) or not \
406+
all(isinstance(tag, six.string_types) for tag in tags):
407+
raise TypeError('Tags must be a list of strings.')
408+
if not all(_TAG_PATTERN.match(tag) for tag in tags):
409+
raise ValueError('Tag format is invalid.')
410+
return tags
411+
412+
413+
def _validate_gcs_tflite_uri(uri):
414+
# GCS Bucket naming rules are complex. The regex is not comprehensive.
415+
# See https://cloud.google.com/storage/docs/naming for full details.
416+
if not _GCS_TFLITE_URI_PATTERN.match(uri):
417+
raise ValueError('GCS TFLite URI format is invalid.')
418+
return uri
419+
420+
def _validate_model_format(model_format):
421+
if not isinstance(model_format, ModelFormat):
422+
raise TypeError('Model format must be a ModelFormat object.')
423+
return model_format
424+
189425
def _validate_list_filter(list_filter):
190426
if list_filter is not None:
191427
if not isinstance(list_filter, six.string_types):

0 commit comments

Comments
 (0)