27
27
import six
28
28
29
29
30
+ from six .moves import urllib
30
31
from firebase_admin import _http_client
31
32
from firebase_admin import _utils
32
33
from firebase_admin import exceptions
@@ -200,6 +201,7 @@ def from_dict(cls, data, app=None):
200
201
data_copy = dict (data )
201
202
tflite_format = None
202
203
tflite_format_data = data_copy .pop ('tfliteModel' , None )
204
+ data_copy .pop ('@type' , None ) # Returned by Operations. (Not needed)
203
205
if tflite_format_data :
204
206
tflite_format = TFLiteFormat .from_dict (tflite_format_data )
205
207
model = Model (model_format = tflite_format )
@@ -495,12 +497,31 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None):
495
497
return TFLiteGCSModelSource (gcs_tflite_uri = gcs_uri , app = app )
496
498
497
499
@staticmethod
498
- def _assert_tf_version_1_enabled ():
500
+ def _assert_tf_enabled ():
499
501
if not _TF_ENABLED :
500
502
raise ImportError ('Failed to import the tensorflow library for Python. Make sure '
501
503
'to install the tensorflow module.' )
502
- if not tf .VERSION .startswith ('1.' ):
503
- raise ImportError ('Expected tensorflow version 1.x, but found {0}' .format (tf .VERSION ))
504
+ if not tf .version .VERSION .startswith ('1.' ) and not tf .version .VERSION .startswith ('2.' ):
505
+ raise ImportError ('Expected tensorflow version 1.x or 2.x, but found {0}'
506
+ .format (tf .version .VERSION ))
507
+
508
+ @staticmethod
509
+ def _tf_convert_from_saved_model (saved_model_dir ):
510
+ # Same for both v1.x and v2.x
511
+ converter = tf .lite .TFLiteConverter .from_saved_model (saved_model_dir )
512
+ return converter .convert ()
513
+
514
+ @staticmethod
515
+ def _tf_convert_from_keras_model (keras_model ):
516
+ # Version 1.x conversion function takes a model file. Version 2.x takes the model itself.
517
+ if tf .version .VERSION .startswith ('1.' ):
518
+ keras_file = 'firebase_keras_model.h5'
519
+ tf .keras .models .save_model (keras_model , keras_file )
520
+ converter = tf .lite .TFLiteConverter .from_keras_model_file (keras_file )
521
+ return converter .convert ()
522
+ else :
523
+ converter = tf .lite .TFLiteConverter .from_keras_model (keras_model )
524
+ return converter .convert ()
504
525
505
526
@classmethod
506
527
def from_saved_model (cls , saved_model_dir , bucket_name = None , app = None ):
@@ -518,9 +539,8 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
518
539
Raises:
519
540
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
520
541
"""
521
- TFLiteGCSModelSource ._assert_tf_version_1_enabled ()
522
- converter = tf .lite .TFLiteConverter .from_saved_model (saved_model_dir )
523
- tflite_model = converter .convert ()
542
+ TFLiteGCSModelSource ._assert_tf_enabled ()
543
+ tflite_model = TFLiteGCSModelSource ._tf_convert_from_saved_model (saved_model_dir )
524
544
open ('firebase_mlkit_model.tflite' , 'wb' ).write (tflite_model )
525
545
return TFLiteGCSModelSource .from_tflite_model_file (
526
546
'firebase_mlkit_model.tflite' , bucket_name , app )
@@ -541,11 +561,8 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None):
541
561
Raises:
542
562
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
543
563
"""
544
- TFLiteGCSModelSource ._assert_tf_version_1_enabled ()
545
- keras_file = 'keras_model.h5'
546
- tf .keras .models .save_model (keras_model , keras_file )
547
- converter = tf .lite .TFLiteConverter .from_keras_model_file (keras_file )
548
- tflite_model = converter .convert ()
564
+ TFLiteGCSModelSource ._assert_tf_enabled ()
565
+ tflite_model = TFLiteGCSModelSource ._tf_convert_from_keras_model (keras_model )
549
566
open ('firebase_mlkit_model.tflite' , 'wb' ).write (tflite_model )
550
567
return TFLiteGCSModelSource .from_tflite_model_file (
551
568
'firebase_mlkit_model.tflite' , bucket_name , app )
@@ -852,12 +869,12 @@ def create_model(self, model):
852
869
853
870
def update_model (self , model , update_mask = None ):
854
871
_validate_model (model , update_mask )
855
- data = { 'model' : model . as_dict ( for_upload = True )}
872
+ path = 'models/{0}' . format ( model . model_id )
856
873
if update_mask is not None :
857
- data [ 'updateMask' ] = update_mask
874
+ path = path + '?updateMask={0}' . format ( update_mask )
858
875
try :
859
876
return self .handle_operation (
860
- self ._client .body ('patch' , url = 'models/{0}' . format ( model . model_id ) , json = data ))
877
+ self ._client .body ('patch' , url = path , json = model . as_dict ( for_upload = True ) ))
861
878
except requests .exceptions .RequestException as error :
862
879
raise _utils .handle_platform_error_from_requests (error )
863
880
@@ -884,15 +901,20 @@ def list_models(self, list_filter, page_size, page_token):
884
901
_validate_list_filter (list_filter )
885
902
_validate_page_size (page_size )
886
903
_validate_page_token (page_token )
887
- payload = {}
904
+ params = {}
888
905
if list_filter :
889
- payload [ 'list_filter ' ] = list_filter
906
+ params [ 'filter ' ] = list_filter
890
907
if page_size :
891
- payload ['page_size' ] = page_size
908
+ params ['page_size' ] = page_size
892
909
if page_token :
893
- payload ['page_token' ] = page_token
910
+ params ['page_token' ] = page_token
911
+ path = 'models'
912
+ if params :
913
+ # pylint: disable=too-many-function-args
914
+ param_str = urllib .parse .urlencode (sorted (params .items ()), True )
915
+ path = path + '?' + param_str
894
916
try :
895
- return self ._client .body ('get' , url = 'models' , json = payload )
917
+ return self ._client .body ('get' , url = path )
896
918
except requests .exceptions .RequestException as error :
897
919
raise _utils .handle_platform_error_from_requests (error )
898
920
0 commit comments