@@ -524,11 +524,13 @@ def _tf_convert_from_keras_model(keras_model):
524
524
return converter .convert ()
525
525
526
526
@classmethod
527
- def from_saved_model (cls , saved_model_dir , bucket_name = None , app = None ):
527
+ def from_saved_model (cls , saved_model_dir , model_file_name = 'firebase_ml_model.tflite' ,
528
+ bucket_name = None , app = None ):
528
529
"""Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS.
529
530
530
531
Args:
531
532
saved_model_dir: The saved model directory.
533
+ model_file_name: The name that the tflite model will be saved as in Cloud Storage.
532
534
bucket_name: The name of an existing bucket. None to use the default bucket configured
533
535
in the app.
534
536
app: Optional. A Firebase app instance (or None to use the default app)
@@ -541,16 +543,18 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
541
543
"""
542
544
TFLiteGCSModelSource ._assert_tf_enabled ()
543
545
tflite_model = TFLiteGCSModelSource ._tf_convert_from_saved_model (saved_model_dir )
544
- open ('firebase_ml_model.tflite' , 'wb' ). write ( tflite_model )
545
- return TFLiteGCSModelSource . from_tflite_model_file (
546
- 'firebase_ml_model.tflite' , bucket_name , app )
546
+ with open (model_file_name , 'wb' ) as model_file :
547
+ model_file . write ( tflite_model )
548
+ return TFLiteGCSModelSource . from_tflite_model_file ( model_file_name , bucket_name , app )
547
549
548
550
@classmethod
549
- def from_keras_model (cls , keras_model , bucket_name = None , app = None ):
551
+ def from_keras_model (cls , keras_model , model_file_name = 'firebase_ml_model.tflite' ,
552
+ bucket_name = None , app = None ):
550
553
"""Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS.
551
554
552
555
Args:
553
556
keras_model: A tf.keras model.
557
+ model_file_name: The name that the tflite model will be saved as in Cloud Storage.
554
558
bucket_name: The name of an existing bucket. None to use the default bucket configured
555
559
in the app.
556
560
app: Optional. A Firebase app instance (or None to use the default app)
@@ -563,9 +567,9 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None):
563
567
"""
564
568
TFLiteGCSModelSource ._assert_tf_enabled ()
565
569
tflite_model = TFLiteGCSModelSource ._tf_convert_from_keras_model (keras_model )
566
- open ('firebase_ml_model.tflite' , 'wb' ). write ( tflite_model )
567
- return TFLiteGCSModelSource . from_tflite_model_file (
568
- 'firebase_ml_model.tflite' , bucket_name , app )
570
+ with open (model_file_name , 'wb' ) as model_file :
571
+ model_file . write ( tflite_model )
572
+ return TFLiteGCSModelSource . from_tflite_model_file ( model_file_name , bucket_name , app )
569
573
570
574
@property
571
575
def gcs_tflite_uri (self ):
0 commit comments