Skip to content

Commit a13d2a7

Browse files
authored
Adding File naming capability to from_saved_model and from_keras_model. (#375)
adding File naming capability for ModelSource
1 parent 079c7e1 commit a13d2a7

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

firebase_admin/ml.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -524,11 +524,13 @@ def _tf_convert_from_keras_model(keras_model):
524524
return converter.convert()
525525

526526
@classmethod
527-
def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
527+
def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite',
528+
bucket_name=None, app=None):
528529
"""Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS.
529530
530531
Args:
531532
saved_model_dir: The saved model directory.
533+
model_file_name: The name that the tflite model will be saved as in Cloud Storage.
532534
bucket_name: The name of an existing bucket. None to use the default bucket configured
533535
in the app.
534536
app: Optional. A Firebase app instance (or None to use the default app)
@@ -541,16 +543,18 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
541543
"""
542544
TFLiteGCSModelSource._assert_tf_enabled()
543545
tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir)
544-
open('firebase_ml_model.tflite', 'wb').write(tflite_model)
545-
return TFLiteGCSModelSource.from_tflite_model_file(
546-
'firebase_ml_model.tflite', bucket_name, app)
546+
with open(model_file_name, 'wb') as model_file:
547+
model_file.write(tflite_model)
548+
return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app)
547549

548550
@classmethod
549-
def from_keras_model(cls, keras_model, bucket_name=None, app=None):
551+
def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite',
552+
bucket_name=None, app=None):
550553
"""Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS.
551554
552555
Args:
553556
keras_model: A tf.keras model.
557+
model_file_name: The name that the tflite model will be saved as in Cloud Storage.
554558
bucket_name: The name of an existing bucket. None to use the default bucket configured
555559
in the app.
556560
app: Optional. A Firebase app instance (or None to use the default app)
@@ -563,9 +567,9 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None):
563567
"""
564568
TFLiteGCSModelSource._assert_tf_enabled()
565569
tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model)
566-
open('firebase_ml_model.tflite', 'wb').write(tflite_model)
567-
return TFLiteGCSModelSource.from_tflite_model_file(
568-
'firebase_ml_model.tflite', bucket_name, app)
570+
with open(model_file_name, 'wb') as model_file:
571+
model_file.write(tflite_model)
572+
return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app)
569573

570574
@property
571575
def gcs_tflite_uri(self):

0 commit comments

Comments
 (0)