23
23
import re
24
24
import time
25
25
import os
26
- import requests
27
- import six
26
+ from urllib import parse
28
27
28
+ import requests
29
29
30
- from six .moves import urllib
31
30
from firebase_admin import _http_client
32
31
from firebase_admin import _utils
33
32
from firebase_admin import exceptions
@@ -175,7 +174,7 @@ def delete_model(model_id, app=None):
175
174
ml_service .delete_model (model_id )
176
175
177
176
178
- class Model ( object ) :
177
+ class Model :
179
178
"""A Firebase ML Model object.
180
179
181
180
Args:
@@ -218,8 +217,7 @@ def __eq__(self, other):
218
217
if isinstance (other , self .__class__ ):
219
218
# pylint: disable=protected-access
220
219
return self ._data == other ._data and self ._model_format == other ._model_format
221
- else :
222
- return False
220
+ return False
223
221
224
222
def __ne__ (self , other ):
225
223
return not self .__eq__ (other )
@@ -341,7 +339,7 @@ def as_dict(self, for_upload=False):
341
339
return copy
342
340
343
341
344
- class ModelFormat ( object ) :
342
+ class ModelFormat :
345
343
"""Abstract base class representing a Model Format such as TFLite."""
346
344
def as_dict (self , for_upload = False ):
347
345
"""Returns a serializable representation of the object."""
@@ -378,8 +376,7 @@ def __eq__(self, other):
378
376
if isinstance (other , self .__class__ ):
379
377
# pylint: disable=protected-access
380
378
return self ._data == other ._data and self ._model_source == other ._model_source
381
- else :
382
- return False
379
+ return False
383
380
384
381
def __ne__ (self , other ):
385
382
return not self .__eq__ (other )
@@ -409,14 +406,14 @@ def as_dict(self, for_upload=False):
409
406
return {'tfliteModel' : copy }
410
407
411
408
412
- class TFLiteModelSource ( object ) :
409
+ class TFLiteModelSource :
413
410
"""Abstract base class representing a model source for TFLite format models."""
414
411
def as_dict (self , for_upload = False ):
415
412
"""Returns a serializable representation of the object."""
416
413
raise NotImplementedError
417
414
418
415
419
- class _CloudStorageClient ( object ) :
416
+ class _CloudStorageClient :
420
417
"""Cloud Storage helper class"""
421
418
422
419
GCS_URI = 'gs://{0}/{1}'
@@ -475,8 +472,7 @@ def __init__(self, gcs_tflite_uri, app=None):
475
472
def __eq__ (self , other ):
476
473
if isinstance (other , self .__class__ ):
477
474
return self ._gcs_tflite_uri == other ._gcs_tflite_uri # pylint: disable=protected-access
478
- else :
479
- return False
475
+ return False
480
476
481
477
def __ne__ (self , other ):
482
478
return not self .__eq__ (other )
@@ -517,15 +513,16 @@ def _tf_convert_from_saved_model(saved_model_dir):
517
513
518
514
@staticmethod
519
515
def _tf_convert_from_keras_model (keras_model ):
516
+ """Converts the given Keras model into a TF Lite model."""
520
517
# Version 1.x conversion function takes a model file. Version 2.x takes the model itself.
521
518
if tf .version .VERSION .startswith ('1.' ):
522
519
keras_file = 'firebase_keras_model.h5'
523
520
tf .keras .models .save_model (keras_model , keras_file )
524
521
converter = tf .lite .TFLiteConverter .from_keras_model_file (keras_file )
525
- return converter .convert ()
526
522
else :
527
523
converter = tf .lite .TFLiteConverter .from_keras_model (keras_model )
528
- return converter .convert ()
524
+
525
+ return converter .convert ()
529
526
530
527
@classmethod
531
528
def from_saved_model (cls , saved_model_dir , model_file_name = 'firebase_ml_model.tflite' ,
@@ -596,7 +593,7 @@ def as_dict(self, for_upload=False):
596
593
return {'gcsTfliteUri' : self ._gcs_tflite_uri }
597
594
598
595
599
- class ListModelsPage ( object ) :
596
+ class ListModelsPage :
600
597
"""Represents a page of models in a firebase project.
601
598
602
599
Provides methods for traversing the models included in this page, as well as
@@ -662,7 +659,7 @@ def iterate_all(self):
662
659
return _ModelIterator (self )
663
660
664
661
665
- class _ModelIterator ( object ) :
662
+ class _ModelIterator :
666
663
"""An iterator that allows iterating over models, one at a time.
667
664
668
665
This implementation loads a page of models into memory, and iterates on them.
@@ -730,7 +727,7 @@ def _validate_display_name(display_name):
730
727
731
728
def _validate_tags (tags ):
732
729
if not isinstance (tags , list ) or not \
733
- all (isinstance (tag , six . string_types ) for tag in tags ):
730
+ all (isinstance (tag , str ) for tag in tags ):
734
731
raise TypeError ('Tags must be a list of strings.' )
735
732
if not all (_TAG_PATTERN .match (tag ) for tag in tags ):
736
733
raise ValueError ('Tag format is invalid.' )
@@ -753,7 +750,7 @@ def _validate_model_format(model_format):
753
750
754
751
def _validate_list_filter (list_filter ):
755
752
if list_filter is not None :
756
- if not isinstance (list_filter , six . string_types ):
753
+ if not isinstance (list_filter , str ):
757
754
raise TypeError ('List filter must be a string or None.' )
758
755
759
756
@@ -769,11 +766,11 @@ def _validate_page_size(page_size):
769
766
770
767
def _validate_page_token (page_token ):
771
768
if page_token is not None :
772
- if not isinstance (page_token , six . string_types ):
769
+ if not isinstance (page_token , str ):
773
770
raise TypeError ('Page token must be a string or None.' )
774
771
775
772
776
- class _MLService ( object ) :
773
+ class _MLService :
777
774
"""Firebase ML service."""
778
775
779
776
PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/'
@@ -811,8 +808,7 @@ def _exponential_backoff(self, current_attempt, stop_time):
811
808
max_seconds_left = (stop_time - datetime .datetime .now ()).total_seconds ()
812
809
if max_seconds_left < 1 : # allow a bit of time for rpc
813
810
raise exceptions .DeadlineExceededError ('Polling max time exceeded.' )
814
- else :
815
- wait_time_seconds = min (wait_time_seconds , max_seconds_left - 1 )
811
+ wait_time_seconds = min (wait_time_seconds , max_seconds_left - 1 )
816
812
time .sleep (wait_time_seconds )
817
813
818
814
def handle_operation (self , operation , wait_for_operation = False , max_time_seconds = None ):
@@ -831,6 +827,7 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
831
827
Raises:
832
828
TypeError: if the operation is not a dictionary.
833
829
ValueError: If the operation is malformed.
830
+ UnknownError: If the server responds with an unexpected response.
834
831
err: If the operation exceeds polling attempts or stop_time
835
832
"""
836
833
if not isinstance (operation , dict ):
@@ -840,31 +837,31 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds
840
837
# Operations which are immediately done don't have an operation name
841
838
if operation .get ('response' ):
842
839
return operation .get ('response' )
843
- elif operation .get ('error' ):
840
+ if operation .get ('error' ):
844
841
raise _utils .handle_operation_error (operation .get ('error' ))
845
842
raise exceptions .UnknownError (message = 'Internal Error: Malformed Operation.' )
846
- else :
847
- op_name = operation .get ('name' )
848
- _ , model_id = _validate_and_parse_operation_name (op_name )
849
- current_attempt = 0
850
- start_time = datetime .datetime .now ()
851
- stop_time = (None if max_time_seconds is None else
852
- start_time + datetime .timedelta (seconds = max_time_seconds ))
853
- while wait_for_operation and not operation .get ('done' ):
854
- # We just got this operation. Wait before getting another
855
- # so we don't exceed the GetOperation maximum request rate.
856
- self ._exponential_backoff (current_attempt , stop_time )
857
- operation = self .get_operation (op_name )
858
- current_attempt += 1
859
-
860
- if operation .get ('done' ):
861
- if operation .get ('response' ):
862
- return operation .get ('response' )
863
- elif operation .get ('error' ):
864
- raise _utils .handle_operation_error (operation .get ('error' ))
865
-
866
- # If the operation is not complete or timed out, return a (locked) model instead
867
- return get_model (model_id ).as_dict ()
843
+
844
+ op_name = operation .get ('name' )
845
+ _ , model_id = _validate_and_parse_operation_name (op_name )
846
+ current_attempt = 0
847
+ start_time = datetime .datetime .now ()
848
+ stop_time = (None if max_time_seconds is None else
849
+ start_time + datetime .timedelta (seconds = max_time_seconds ))
850
+ while wait_for_operation and not operation .get ('done' ):
851
+ # We just got this operation. Wait before getting another
852
+ # so we don't exceed the GetOperation maximum request rate.
853
+ self ._exponential_backoff (current_attempt , stop_time )
854
+ operation = self .get_operation (op_name )
855
+ current_attempt += 1
856
+
857
+ if operation .get ('done' ):
858
+ if operation .get ('response' ):
859
+ return operation .get ('response' )
860
+ if operation .get ('error' ):
861
+ raise _utils .handle_operation_error (operation .get ('error' ))
862
+
863
+ # If the operation is not complete or timed out, return a (locked) model instead
864
+ return get_model (model_id ).as_dict ()
868
865
869
866
870
867
def create_model (self , model ):
@@ -918,8 +915,7 @@ def list_models(self, list_filter, page_size, page_token):
918
915
params ['page_token' ] = page_token
919
916
path = 'models'
920
917
if params :
921
- # pylint: disable=too-many-function-args
922
- param_str = urllib .parse .urlencode (sorted (params .items ()), True )
918
+ param_str = parse .urlencode (sorted (params .items ()), True )
923
919
path = path + '?' + param_str
924
920
try :
925
921
return self ._client .body ('get' , url = path )
0 commit comments