18
18
deleting, publishing and unpublishing Firebase ML Kit models.
19
19
"""
20
20
21
+
21
22
import datetime
22
23
import numbers
23
24
import re
30
31
from firebase_admin import _utils
31
32
from firebase_admin import exceptions
32
33
34
+ # pylint: disable=import-error,no-name-in-module
35
+ try :
36
+ from firebase_admin import storage
37
+ _GCS_ENABLED = True
38
+ except ImportError :
39
+ _GCS_ENABLED = False
40
+
41
+ # pylint: disable=import-error,no-name-in-module
42
+ try :
43
+ import tensorflow as tf
44
+ _TF_ENABLED = True
45
+ except ImportError :
46
+ _TF_ENABLED = False
33
47
34
48
_MLKIT_ATTRIBUTE = '_mlkit'
35
49
_MAX_PAGE_SIZE = 100
36
50
_MODEL_ID_PATTERN = re .compile (r'^[A-Za-z0-9_-]{1,60}$' )
37
51
_DISPLAY_NAME_PATTERN = re .compile (r'^[A-Za-z0-9_-]{1,60}$' )
38
52
_TAG_PATTERN = re .compile (r'^[A-Za-z0-9_-]{1,60}$' )
39
- _GCS_TFLITE_URI_PATTERN = re .compile (r'^gs://[a-z0-9_.-]{3,63}/.+' )
53
+ _GCS_TFLITE_URI_PATTERN = re .compile (
54
+ r'^gs://(?P<bucket_name>[a-z0-9_.-]{3,63})/(?P<blob_name>.+)$' )
40
55
_RESOURCE_NAME_PATTERN = re .compile (
41
56
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$' )
42
57
_OPERATION_NAME_PATTERN = re .compile (
@@ -301,16 +316,16 @@ def model_format(self, model_format):
301
316
self ._model_format = model_format #Can be None
302
317
return self
303
318
304
- def as_dict (self ):
319
+ def as_dict (self , for_upload = False ):
305
320
copy = dict (self ._data )
306
321
if self ._model_format :
307
- copy .update (self ._model_format .as_dict ())
322
+ copy .update (self ._model_format .as_dict (for_upload = for_upload ))
308
323
return copy
309
324
310
325
311
326
class ModelFormat (object ):
312
327
"""Abstract base class representing a Model Format such as TFLite."""
313
- def as_dict (self ):
328
+ def as_dict (self , for_upload = False ):
314
329
raise NotImplementedError
315
330
316
331
@@ -364,22 +379,70 @@ def model_source(self, model_source):
364
379
def size_bytes (self ):
365
380
return self ._data .get ('sizeBytes' )
366
381
367
- def as_dict (self ):
382
+ def as_dict (self , for_upload = False ):
368
383
copy = dict (self ._data )
369
384
if self ._model_source :
370
- copy .update (self ._model_source .as_dict ())
385
+ copy .update (self ._model_source .as_dict (for_upload = for_upload ))
371
386
return {'tfliteModel' : copy }
372
387
373
388
374
389
class TFLiteModelSource (object ):
375
390
"""Abstract base class representing a model source for TFLite format models."""
376
- def as_dict (self ):
391
+ def as_dict (self , for_upload = False ):
377
392
raise NotImplementedError
378
393
379
394
395
+ class _CloudStorageClient (object ):
396
+ """Cloud Storage helper class"""
397
+
398
+ GCS_URI = 'gs://{0}/{1}'
399
+ BLOB_NAME = 'Firebase/MLKit/Models/{0}'
400
+
401
+ @staticmethod
402
+ def _assert_gcs_enabled ():
403
+ if not _GCS_ENABLED :
404
+ raise ImportError ('Failed to import the Cloud Storage library for Python. Make sure '
405
+ 'to install the "google-cloud-storage" module.' )
406
+
407
+ @staticmethod
408
+ def _parse_gcs_tflite_uri (uri ):
409
+ # GCS Bucket naming rules are complex. The regex is not comprehensive.
410
+ # See https://cloud.google.com/storage/docs/naming for full details.
411
+ matcher = _GCS_TFLITE_URI_PATTERN .match (uri )
412
+ if not matcher :
413
+ raise ValueError ('GCS TFLite URI format is invalid.' )
414
+ return matcher .group ('bucket_name' ), matcher .group ('blob_name' )
415
+
416
+ @staticmethod
417
+ def upload (bucket_name , model_file_name , app ):
418
+ _CloudStorageClient ._assert_gcs_enabled ()
419
+ bucket = storage .bucket (bucket_name , app = app )
420
+ blob_name = _CloudStorageClient .BLOB_NAME .format (model_file_name )
421
+ blob = bucket .blob (blob_name )
422
+ blob .upload_from_filename (model_file_name )
423
+ return _CloudStorageClient .GCS_URI .format (bucket .name , blob_name )
424
+
425
+ @staticmethod
426
+ def sign_uri (gcs_tflite_uri , app ):
427
+ """Makes the gcs_tflite_uri readable for GET for 10 minutes via signed_uri."""
428
+ _CloudStorageClient ._assert_gcs_enabled ()
429
+ bucket_name , blob_name = _CloudStorageClient ._parse_gcs_tflite_uri (gcs_tflite_uri )
430
+ bucket = storage .bucket (bucket_name , app = app )
431
+ blob = bucket .blob (blob_name )
432
+ return blob .generate_signed_url (
433
+ version = 'v4' ,
434
+ expiration = datetime .timedelta (minutes = 10 ),
435
+ method = 'GET'
436
+ )
437
+
438
+
380
439
class TFLiteGCSModelSource (TFLiteModelSource ):
381
440
"""TFLite model source representing a tflite model file stored in GCS."""
382
- def __init__ (self , gcs_tflite_uri ):
441
+
442
+ _STORAGE_CLIENT = _CloudStorageClient ()
443
+
444
+ def __init__ (self , gcs_tflite_uri , app = None ):
445
+ self ._app = app
383
446
self ._gcs_tflite_uri = _validate_gcs_tflite_uri (gcs_tflite_uri )
384
447
385
448
def __eq__ (self , other ):
@@ -391,6 +454,81 @@ def __eq__(self, other):
391
454
def __ne__ (self , other ):
392
455
return not self .__eq__ (other )
393
456
457
+ @classmethod
458
+ def from_tflite_model_file (cls , model_file_name , bucket_name = None , app = None ):
459
+ """Uploads the model file to an existing Google Cloud Storage bucket.
460
+
461
+ Args:
462
+ model_file_name: The name of the model file.
463
+ bucket_name: The name of an existing bucket. None to use the default bucket configured
464
+ in the app.
465
+ app: A Firebase app instance (or None to use the default app).
466
+
467
+ Returns:
468
+ TFLiteGCSModelSource: The source created from the model_file
469
+
470
+ Raises:
471
+ ImportError: If the Cloud Storage Library has not been installed.
472
+ """
473
+ gcs_uri = TFLiteGCSModelSource ._STORAGE_CLIENT .upload (bucket_name , model_file_name , app )
474
+ return TFLiteGCSModelSource (gcs_tflite_uri = gcs_uri , app = app )
475
+
476
+ @staticmethod
477
+ def _assert_tf_version_1_enabled ():
478
+ if not _TF_ENABLED :
479
+ raise ImportError ('Failed to import the tensorflow library for Python. Make sure '
480
+ 'to install the tensorflow module.' )
481
+ if not tf .VERSION .startswith ('1.' ):
482
+ raise ImportError ('Expected tensorflow version 1.x, but found {0}' .format (tf .VERSION ))
483
+
484
+ @classmethod
485
+ def from_saved_model (cls , saved_model_dir , bucket_name = None , app = None ):
486
+ """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS.
487
+
488
+ Args:
489
+ saved_model_dir: The saved model directory.
490
+ bucket_name: The name of an existing bucket. None to use the default bucket configured
491
+ in the app.
492
+ app: Optional. A Firebase app instance (or None to use the default app)
493
+
494
+ Returns:
495
+ TFLiteGCSModelSource: The source created from the saved_model_dir
496
+
497
+ Raises:
498
+ ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
499
+ """
500
+ TFLiteGCSModelSource ._assert_tf_version_1_enabled ()
501
+ converter = tf .lite .TFLiteConverter .from_saved_model (saved_model_dir )
502
+ tflite_model = converter .convert ()
503
+ open ('firebase_mlkit_model.tflite' , 'wb' ).write (tflite_model )
504
+ return TFLiteGCSModelSource .from_tflite_model_file (
505
+ 'firebase_mlkit_model.tflite' , bucket_name , app )
506
+
507
+ @classmethod
508
+ def from_keras_model (cls , keras_model , bucket_name = None , app = None ):
509
+ """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS.
510
+
511
+ Args:
512
+ keras_model: A tf.keras model.
513
+ bucket_name: The name of an existing bucket. None to use the default bucket configured
514
+ in the app.
515
+ app: Optional. A Firebase app instance (or None to use the default app)
516
+
517
+ Returns:
518
+ TFLiteGCSModelSource: The source created from the keras_model
519
+
520
+ Raises:
521
+ ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
522
+ """
523
+ TFLiteGCSModelSource ._assert_tf_version_1_enabled ()
524
+ keras_file = 'keras_model.h5'
525
+ tf .keras .models .save_model (keras_model , keras_file )
526
+ converter = tf .lite .TFLiteConverter .from_keras_model_file (keras_file )
527
+ tflite_model = converter .convert ()
528
+ open ('firebase_mlkit_model.tflite' , 'wb' ).write (tflite_model )
529
+ return TFLiteGCSModelSource .from_tflite_model_file (
530
+ 'firebase_mlkit_model.tflite' , bucket_name , app )
531
+
394
532
@property
395
533
def gcs_tflite_uri (self ):
396
534
return self ._gcs_tflite_uri
@@ -399,10 +537,15 @@ def gcs_tflite_uri(self):
399
537
def gcs_tflite_uri (self , gcs_tflite_uri ):
400
538
self ._gcs_tflite_uri = _validate_gcs_tflite_uri (gcs_tflite_uri )
401
539
402
- def as_dict (self ):
403
- return {"gcsTfliteUri" : self ._gcs_tflite_uri }
540
+ def _get_signed_gcs_tflite_uri (self ):
541
+ """Signs the GCS uri, so the model file can be uploaded to Firebase ML Kit and verified."""
542
+ return TFLiteGCSModelSource ._STORAGE_CLIENT .sign_uri (self ._gcs_tflite_uri , self ._app )
543
+
544
+ def as_dict (self , for_upload = False ):
545
+ if for_upload :
546
+ return {'gcsTfliteUri' : self ._get_signed_gcs_tflite_uri ()}
404
547
405
- #TODO(ifielker): implement from_saved_model etc.
548
+ return { 'gcsTfliteUri' : self . _gcs_tflite_uri }
406
549
407
550
408
551
class ListModelsPage (object ):
@@ -671,13 +814,13 @@ def create_model(self, model):
671
814
_validate_model (model )
672
815
try :
673
816
return self .handle_operation (
674
- self ._client .body ('post' , url = 'models' , json = model .as_dict ()))
817
+ self ._client .body ('post' , url = 'models' , json = model .as_dict (for_upload = True )))
675
818
except requests .exceptions .RequestException as error :
676
819
raise _utils .handle_platform_error_from_requests (error )
677
820
678
821
def update_model (self , model , update_mask = None ):
679
822
_validate_model (model , update_mask )
680
- data = {'model' : model .as_dict ()}
823
+ data = {'model' : model .as_dict (for_upload = True )}
681
824
if update_mask is not None :
682
825
data ['updateMask' ] = update_mask
683
826
try :
0 commit comments