diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 8e78a26ce..a135d2d7f 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -196,6 +196,7 @@ def __init__(self, display_name=None, tags=None, model_format=None): @classmethod def from_dict(cls, data, app=None): + """Create an instance of the object from a dict.""" data_copy = dict(data) tflite_format = None tflite_format_data = data_copy.pop('tfliteModel', None) @@ -223,6 +224,7 @@ def __ne__(self, other): @property def model_id(self): + """The model's ID, unique to the project.""" if not self._data.get('name'): return None _, model_id = _validate_and_parse_name(self._data.get('name')) @@ -230,6 +232,8 @@ def model_id(self): @property def display_name(self): + """The model's display name, used to refer to the model in code and in + the Firebase console.""" return self._data.get('displayName') @display_name.setter @@ -239,7 +243,7 @@ def display_name(self, display_name): @property def create_time(self): - """Returns the creation timestamp""" + """The time the model was created.""" seconds = self._data.get('createTime', {}).get('seconds') if not isinstance(seconds, numbers.Number): return None @@ -248,7 +252,7 @@ def create_time(self): @property def update_time(self): - """Returns the last update timestamp""" + """The time the model was last updated.""" seconds = self._data.get('updateTime', {}).get('seconds') if not isinstance(seconds, numbers.Number): return None @@ -257,22 +261,28 @@ def update_time(self): @property def validation_error(self): + """Validation error message.""" return self._data.get('state', {}).get('validationError', {}).get('message') @property def published(self): + """True if the model is published and available for clients to + download.""" return bool(self._data.get('state', {}).get('published')) @property def etag(self): + """The entity tag (ETag) of the model resource.""" return self._data.get('etag') @property def model_hash(self): + """SHA256 hash of the model binary.""" return self._data.get('modelHash') @property def tags(self): + """Tag strings, used for filtering query results.""" return self._data.get('tags') @tags.setter @@ -282,6 +292,7 @@ def tags(self, tags): @property def locked(self): + """True if the Model object is locked by an active operation.""" return bool(self._data.get('activeOperations') and len(self._data.get('activeOperations')) > 0) @@ -307,6 +318,8 @@ def wait_for_unlocked(self, max_time_seconds=None): @property def model_format(self): + """The model's ``ModelFormat`` object, which represents the model's + format and storage location.""" return self._model_format @model_format.setter @@ -317,6 +330,7 @@ def model_format(self, model_format): return self def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" copy = dict(self._data) if self._model_format: copy.update(self._model_format.as_dict(for_upload=for_upload)) @@ -326,6 +340,7 @@ def as_dict(self, for_upload=False): class ModelFormat(object): """Abstract base class representing a Model Format such as TFLite.""" def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" raise NotImplementedError @@ -344,6 +359,7 @@ def __init__(self, model_source=None): @classmethod def from_dict(cls, data): + """Create an instance of the object from a dict.""" data_copy = dict(data) model_source = None gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None) @@ -366,6 +382,7 @@ def __ne__(self, other): @property def model_source(self): + """The TF Lite model's location.""" return self._model_source @model_source.setter @@ -377,9 +394,11 @@ def model_source(self, model_source): @property def size_bytes(self): + """The size in bytes of the TF Lite model.""" return self._data.get('sizeBytes') def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" copy = dict(self._data) if self._model_source: copy.update(self._model_source.as_dict(for_upload=for_upload)) @@ -389,6 +408,7 @@ def as_dict(self, for_upload=False): class TFLiteModelSource(object): """Abstract base class representing a model source for TFLite format models.""" def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" raise NotImplementedError @@ -415,6 +435,7 @@ def _parse_gcs_tflite_uri(uri): @staticmethod def upload(bucket_name, model_file_name, app): + """Upload a model file to the specified Storage bucket.""" _CloudStorageClient._assert_gcs_enabled() bucket = storage.bucket(bucket_name, app=app) blob_name = _CloudStorageClient.BLOB_NAME.format(model_file_name) @@ -531,6 +552,7 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): @property def gcs_tflite_uri(self): + """URI of the model file in Cloud Storage.""" return self._gcs_tflite_uri @gcs_tflite_uri.setter @@ -542,6 +564,7 @@ def _get_signed_gcs_tflite_uri(self): return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app) def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" if for_upload: return {'gcsTfliteUri': self._get_signed_gcs_tflite_uri()} @@ -578,11 +601,12 @@ def list_filter(self): @property def next_page_token(self): + """Token identifying the next page of results.""" return self._list_response.get('nextPageToken', '') @property def has_next_page(self): - """A boolean indicating whether more pages are available.""" + """True if more pages are available.""" return bool(self.next_page_token) def get_next_page(self):