-
Notifications
You must be signed in to change notification settings - Fork 6k
[Model Card] standardize dreambooth model card #6729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fe67424
8c84d91
9bc1fbc
f463fe2
dc9afd3
108109c
458c37a
52d8131
07de49d
2f647dd
cc1c73e
2650fed
2c0a3cf
2f4e9b2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,7 @@ | |
import traceback | ||
import warnings | ||
from pathlib import Path | ||
from typing import Dict, Optional, Union | ||
from typing import Dict, List, Optional, Union | ||
from uuid import uuid4 | ||
|
||
from huggingface_hub import ( | ||
|
@@ -65,7 +65,7 @@ | |
|
||
logger = get_logger(__name__) | ||
|
||
|
||
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md" | ||
SESSION_ID = uuid4().hex | ||
|
||
|
||
|
@@ -94,43 +94,87 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: | |
|
||
|
||
def load_or_create_model_card( | ||
repo_id_or_path: Optional[str] = None, token: Optional[str] = None, is_pipeline: bool = False | ||
repo_id_or_path: str = None, | ||
token: Optional[str] = None, | ||
is_pipeline: bool = False, | ||
from_training: bool = False, | ||
model_description: Optional[str] = None, | ||
base_model: str = None, | ||
prompt: Optional[str] = None, | ||
license: Optional[str] = None, | ||
widget: Optional[List[dict]] = None, | ||
inference: Optional[bool] = None, | ||
) -> ModelCard: | ||
""" | ||
Loads or creates a model card. | ||
|
||
Args: | ||
repo_id (`str`): | ||
The repo_id where to look for the model card. | ||
repo_id_or_path (`str`): | ||
The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card. | ||
token (`str`, *optional*): | ||
Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details. | ||
is_pipeline (`bool`, *optional*): | ||
is_pipeline (`bool`): | ||
Boolean to indicate if we're adding tag to a [`DiffusionPipeline`]. | ||
from_training: (`bool`): Boolean flag to denote if the model card is being created from a training script. | ||
model_description (`str`, *optional*): Model description to add to the model card. Helpful when using | ||
`load_or_create_model_card` from a training script. | ||
base_model (`str`): Base model identifier (e.g., "stabilityai/stable-diffusion-xl-base-1.0"). Useful | ||
for DreamBooth-like training. | ||
prompt (`str`, *optional*): Prompt used for training. Useful for DreamBooth-like training. | ||
license: (`str`, *optional*): License of the output artifact. Helpful when using | ||
`load_or_create_model_card` from a training script. | ||
widget (`List[dict]`, *optional*): Widget to accompany a gallery template. | ||
inference: (`bool`, optional): Whether to turn on inference widget. Helpful when using | ||
`load_or_create_model_card` from a training script. | ||
""" | ||
if not is_jinja_available(): | ||
raise ValueError( | ||
"Modelcard rendering is based on Jinja templates." | ||
" Please make sure to have `jinja` installed before using `create_model_card`." | ||
" Please make sure to have `jinja` installed before using `load_or_create_model_card`." | ||
" To install it, please run `pip install Jinja2`." | ||
) | ||
|
||
try: | ||
# Check if the model card is present on the remote repo | ||
model_card = ModelCard.load(repo_id_or_path, token=token) | ||
except EntryNotFoundError: | ||
# Otherwise create a simple model card from template | ||
component = "pipeline" if is_pipeline else "model" | ||
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated." | ||
card_data = ModelCardData() | ||
model_card = ModelCard.from_template(card_data, model_description=model_description) | ||
except (EntryNotFoundError, RepositoryNotFoundError): | ||
# Otherwise create a model card from template | ||
if from_training: | ||
model_card = ModelCard.from_template( | ||
card_data=ModelCardData( # Card metadata object that will be converted to YAML block | ||
license=license, | ||
library_name="diffusers", | ||
inference=inference, | ||
base_model=base_model, | ||
instance_prompt=prompt, | ||
widget=widget, | ||
), | ||
template_path=MODEL_CARD_TEMPLATE_PATH, | ||
model_description=model_description, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (maybe only if widget is not none and There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That means the existing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fine for me! |
||
) | ||
else: | ||
card_data = ModelCardData() | ||
component = "pipeline" if is_pipeline else "model" | ||
if model_description is None: | ||
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated." | ||
model_card = ModelCard.from_template(card_data, model_description=model_description) | ||
|
||
return model_card | ||
|
||
|
||
def populate_model_card(model_card: ModelCard) -> ModelCard: | ||
"""Populates the `model_card` with library name.""" | ||
def populate_model_card(model_card: ModelCard, tags: Union[str, List[str]] = None) -> ModelCard: | ||
"""Populates the `model_card` with library name and optional tags.""" | ||
if model_card.data.library_name is None: | ||
model_card.data.library_name = "diffusers" | ||
|
||
if tags is not None: | ||
if isinstance(tags, str): | ||
tags = [tags] | ||
if model_card.data.tags is None: | ||
model_card.data.tags = [] | ||
for tag in tags: | ||
model_card.data.tags.append(tag) | ||
|
||
return model_card | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was having compulsion disorder :3 So, decided to fix these.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would even update to this 😄