18
18
deleting, publishing and unpublishing Firebase ML Kit models.
19
19
"""
20
20
21
+ import datetime
22
+ import numbers
21
23
import re
22
24
import requests
23
25
import six
28
30
29
31
_MLKIT_ATTRIBUTE = '_mlkit'
30
32
_MAX_PAGE_SIZE = 100
33
+ _MODEL_ID_PATTERN = re .compile (r'^[A-Za-z0-9_-]{1,60}$' )
34
+ _DISPLAY_NAME_PATTERN = re .compile (r'^[A-Za-z0-9_-]{1,60}$' )
35
+ _TAG_PATTERN = re .compile (r'^[A-Za-z0-9_-]{1,60}$' )
36
+ _GCS_TFLITE_URI_PATTERN = re .compile (r'^gs://[a-z0-9_.-]{3,63}/.+' )
37
+ _RESOURCE_NAME_PATTERN = re .compile (
38
+ r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$' )
31
39
32
40
33
41
def _get_mlkit_service (app ):
@@ -47,7 +55,7 @@ def _get_mlkit_service(app):
47
55
48
56
def get_model (model_id , app = None ):
49
57
mlkit_service = _get_mlkit_service (app )
50
- return Model (mlkit_service .get_model (model_id ))
58
+ return Model . from_dict (mlkit_service .get_model (model_id ))
51
59
52
60
53
61
def list_models (list_filter = None , page_size = None , page_token = None , app = None ):
@@ -62,29 +70,222 @@ def delete_model(model_id, app=None):
62
70
63
71
64
72
class Model (object ):
65
- """A Firebase ML Kit Model object."""
66
- def __init__ (self , data ):
67
- """Created from a data dictionary."""
68
- self ._data = data
73
+ """A Firebase ML Kit Model object.
74
+
75
+ Args:
76
+ display_name: The display name of your model - used to identify your model in code.
77
+ tags: Optional list of strings associated with your model. Can be used in list queries.
78
+ model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details.
79
+ """
80
+ def __init__ (self , display_name = None , tags = None , model_format = None ):
81
+ self ._data = {}
82
+ self ._model_format = None
83
+
84
+ if display_name is not None :
85
+ self .display_name = display_name
86
+ if tags is not None :
87
+ self .tags = tags
88
+ if model_format is not None :
89
+ self .model_format = model_format
90
+
91
+ @classmethod
92
+ def from_dict (cls , data ):
93
+ data_copy = dict (data )
94
+ tflite_format = None
95
+ tflite_format_data = data_copy .pop ('tfliteModel' , None )
96
+ if tflite_format_data :
97
+ tflite_format = TFLiteFormat .from_dict (tflite_format_data )
98
+ model = Model (model_format = tflite_format )
99
+ model ._data = data_copy # pylint: disable=protected-access
100
+ return model
69
101
70
102
def __eq__ (self , other ):
71
103
if isinstance (other , self .__class__ ):
72
- return self ._data == other ._data # pylint: disable=protected-access
104
+ # pylint: disable=protected-access
105
+ return self ._data == other ._data and self ._model_format == other ._model_format
73
106
else :
74
107
return False
75
108
76
109
def __ne__ (self , other ):
77
110
return not self .__eq__ (other )
78
111
79
112
@property
80
- def name (self ):
81
- return self ._data ['name' ]
113
+ def model_id (self ):
114
+ if not self ._data .get ('name' ):
115
+ return None
116
+ _ , model_id = _validate_and_parse_name (self ._data .get ('name' ))
117
+ return model_id
82
118
83
119
@property
84
120
def display_name (self ):
85
- return self ._data ['displayName' ]
121
+ return self ._data .get ('displayName' )
122
+
123
+ @display_name .setter
124
+ def display_name (self , display_name ):
125
+ self ._data ['displayName' ] = _validate_display_name (display_name )
126
+ return self
127
+
128
+ @property
129
+ def create_time (self ):
130
+ """Returns the creation timestamp"""
131
+ seconds = self ._data .get ('createTime' , {}).get ('seconds' )
132
+ if not isinstance (seconds , numbers .Number ):
133
+ return None
134
+
135
+ return datetime .datetime .fromtimestamp (float (seconds ))
136
+
137
+ @property
138
+ def update_time (self ):
139
+ """Returns the last update timestamp"""
140
+ seconds = self ._data .get ('updateTime' , {}).get ('seconds' )
141
+ if not isinstance (seconds , numbers .Number ):
142
+ return None
86
143
87
- #TODO(ifielker): define the rest of the Model properties etc
144
+ return datetime .datetime .fromtimestamp (float (seconds ))
145
+
146
+ @property
147
+ def validation_error (self ):
148
+ return self ._data .get ('state' , {}).get ('validationError' , {}).get ('message' )
149
+
150
+ @property
151
+ def published (self ):
152
+ return bool (self ._data .get ('state' , {}).get ('published' ))
153
+
154
+ @property
155
+ def etag (self ):
156
+ return self ._data .get ('etag' )
157
+
158
+ @property
159
+ def model_hash (self ):
160
+ return self ._data .get ('modelHash' )
161
+
162
+ @property
163
+ def tags (self ):
164
+ return self ._data .get ('tags' )
165
+
166
+ @tags .setter
167
+ def tags (self , tags ):
168
+ self ._data ['tags' ] = _validate_tags (tags )
169
+ return self
170
+
171
+ @property
172
+ def locked (self ):
173
+ return bool (self ._data .get ('activeOperations' ) and
174
+ len (self ._data .get ('activeOperations' )) > 0 )
175
+
176
+ @property
177
+ def model_format (self ):
178
+ return self ._model_format
179
+
180
+ @model_format .setter
181
+ def model_format (self , model_format ):
182
+ if model_format is not None :
183
+ _validate_model_format (model_format )
184
+ self ._model_format = model_format #Can be None
185
+ return self
186
+
187
+ def as_dict (self ):
188
+ copy = dict (self ._data )
189
+ if self ._model_format :
190
+ copy .update (self ._model_format .as_dict ())
191
+ return copy
192
+
193
+
194
+ class ModelFormat (object ):
195
+ """Abstract base class representing a Model Format such as TFLite."""
196
+ def as_dict (self ):
197
+ raise NotImplementedError
198
+
199
+
200
+ class TFLiteFormat (ModelFormat ):
201
+ """Model format representing a TFLite model.
202
+
203
+ Args:
204
+ model_source: A TFLiteModelSource sub class. Specifies the details of the model source.
205
+ """
206
+ def __init__ (self , model_source = None ):
207
+ self ._data = {}
208
+ self ._model_source = None
209
+
210
+ if model_source is not None :
211
+ self .model_source = model_source
212
+
213
+ @classmethod
214
+ def from_dict (cls , data ):
215
+ data_copy = dict (data )
216
+ model_source = None
217
+ gcs_tflite_uri = data_copy .pop ('gcsTfliteUri' , None )
218
+ if gcs_tflite_uri :
219
+ model_source = TFLiteGCSModelSource (gcs_tflite_uri = gcs_tflite_uri )
220
+ tflite_format = TFLiteFormat (model_source = model_source )
221
+ tflite_format ._data = data_copy # pylint: disable=protected-access
222
+ return tflite_format
223
+
224
+
225
+ def __eq__ (self , other ):
226
+ if isinstance (other , self .__class__ ):
227
+ # pylint: disable=protected-access
228
+ return self ._data == other ._data and self ._model_source == other ._model_source
229
+ else :
230
+ return False
231
+
232
+ def __ne__ (self , other ):
233
+ return not self .__eq__ (other )
234
+
235
+ @property
236
+ def model_source (self ):
237
+ return self ._model_source
238
+
239
+ @model_source .setter
240
+ def model_source (self , model_source ):
241
+ if model_source is not None :
242
+ if not isinstance (model_source , TFLiteModelSource ):
243
+ raise TypeError ('Model source must be a TFLiteModelSource object.' )
244
+ self ._model_source = model_source # Can be None
245
+
246
+ @property
247
+ def size_bytes (self ):
248
+ return self ._data .get ('sizeBytes' )
249
+
250
+ def as_dict (self ):
251
+ copy = dict (self ._data )
252
+ if self ._model_source :
253
+ copy .update (self ._model_source .as_dict ())
254
+ return {'tfliteModel' : copy }
255
+
256
+
257
+ class TFLiteModelSource (object ):
258
+ """Abstract base class representing a model source for TFLite format models."""
259
+ def as_dict (self ):
260
+ raise NotImplementedError
261
+
262
+
263
+ class TFLiteGCSModelSource (TFLiteModelSource ):
264
+ """TFLite model source representing a tflite model file stored in GCS."""
265
+ def __init__ (self , gcs_tflite_uri ):
266
+ self ._gcs_tflite_uri = _validate_gcs_tflite_uri (gcs_tflite_uri )
267
+
268
+ def __eq__ (self , other ):
269
+ if isinstance (other , self .__class__ ):
270
+ return self ._gcs_tflite_uri == other ._gcs_tflite_uri # pylint: disable=protected-access
271
+ else :
272
+ return False
273
+
274
+ def __ne__ (self , other ):
275
+ return not self .__eq__ (other )
276
+
277
+ @property
278
+ def gcs_tflite_uri (self ):
279
+ return self ._gcs_tflite_uri
280
+
281
+ @gcs_tflite_uri .setter
282
+ def gcs_tflite_uri (self , gcs_tflite_uri ):
283
+ self ._gcs_tflite_uri = _validate_gcs_tflite_uri (gcs_tflite_uri )
284
+
285
+ def as_dict (self ):
286
+ return {"gcsTfliteUri" : self ._gcs_tflite_uri }
287
+
288
+ #TODO(ifielker): implement from_saved_model etc.
88
289
89
290
90
291
class ListModelsPage (object ):
@@ -105,7 +306,7 @@ def __init__(self, list_models_func, list_filter, page_size, page_token):
105
306
@property
106
307
def models (self ):
107
308
"""A list of Models from this page."""
108
- return [Model (model ) for model in self ._list_response .get ('models' , [])]
309
+ return [Model . from_dict (model ) for model in self ._list_response .get ('models' , [])]
109
310
110
311
@property
111
312
def list_filter (self ):
@@ -179,13 +380,48 @@ def __iter__(self):
179
380
return self
180
381
181
382
383
+ def _validate_and_parse_name (name ):
384
+ # The resource name is added automatically from API call responses.
385
+ # The only way it could be invalid is if someone tries to
386
+ # create a model from a dictionary manually and does it incorrectly.
387
+ matcher = _RESOURCE_NAME_PATTERN .match (name )
388
+ if not matcher :
389
+ raise ValueError ('Model resource name format is invalid.' )
390
+ return matcher .group ('project_id' ), matcher .group ('model_id' )
391
+
392
+
182
393
def _validate_model_id (model_id ):
183
- if not isinstance (model_id , six .string_types ):
184
- raise TypeError ('Model ID must be a string.' )
185
- if not re .match (r'^[A-Za-z0-9_-]{1,60}$' , model_id ):
394
+ if not _MODEL_ID_PATTERN .match (model_id ):
186
395
raise ValueError ('Model ID format is invalid.' )
187
396
188
397
398
+ def _validate_display_name (display_name ):
399
+ if not _DISPLAY_NAME_PATTERN .match (display_name ):
400
+ raise ValueError ('Display name format is invalid.' )
401
+ return display_name
402
+
403
+
404
+ def _validate_tags (tags ):
405
+ if not isinstance (tags , list ) or not \
406
+ all (isinstance (tag , six .string_types ) for tag in tags ):
407
+ raise TypeError ('Tags must be a list of strings.' )
408
+ if not all (_TAG_PATTERN .match (tag ) for tag in tags ):
409
+ raise ValueError ('Tag format is invalid.' )
410
+ return tags
411
+
412
+
413
+ def _validate_gcs_tflite_uri (uri ):
414
+ # GCS Bucket naming rules are complex. The regex is not comprehensive.
415
+ # See https://cloud.google.com/storage/docs/naming for full details.
416
+ if not _GCS_TFLITE_URI_PATTERN .match (uri ):
417
+ raise ValueError ('GCS TFLite URI format is invalid.' )
418
+ return uri
419
+
420
+ def _validate_model_format (model_format ):
421
+ if not isinstance (model_format , ModelFormat ):
422
+ raise TypeError ('Model format must be a ModelFormat object.' )
423
+ return model_format
424
+
189
425
def _validate_list_filter (list_filter ):
190
426
if list_filter is not None :
191
427
if not isinstance (list_filter , six .string_types ):
0 commit comments