Skip to content

[Model Card] standardize dreambooth model card #6729

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 16 additions & 17 deletions examples/custom_diffusion/train_custom_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available


Expand All @@ -78,30 +79,28 @@ def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"

yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- custom-diffusion
inference: true
---
"""
model_card = f"""
model_description = f"""
# Custom Diffusion - {repo_id}

These are Custom Diffusion adaption weights for {base_model}. The weights were trained on {prompt} using [Custom Diffusion](https://www.cs.cmu.edu/~custom-diffusion). You can find some example images in the following. \n
{img_str}

\nFor more details on the training, please follow [this link](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion).
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
instance_prompt=prompt,
model_description=model_description,
inference=True,
)

tags = ["text-to-image", "diffusers", "stable-diffusion", "stable-diffusion-diffusers", "custom-diffusion"]
model_card = populate_model_card(model_card, tags=tags)

model_card.save(os.path.join(repo_folder, "README.md"))


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
Expand Down
54 changes: 29 additions & 25 deletions examples/dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module

Expand All @@ -69,33 +70,20 @@

def save_model_card(
repo_id: str,
images=None,
base_model=str,
images: list = None,
base_model: str = None,
train_text_encoder=False,
prompt=str,
repo_folder=None,
prompt: str = None,
repo_folder: str = None,
pipeline: DiffusionPipeline = None,
):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"

yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
- text-to-image
- diffusers
- dreambooth
inference: true
---
"""
model_card = f"""
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"

model_description = f"""
# DreamBooth - {repo_id}

This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
Expand All @@ -104,8 +92,24 @@ def save_model_card(

DreamBooth for the text encoder was enabled: {train_text_encoder}.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
instance_prompt=prompt,
model_description=model_description,
inference=True,
)

tags = ["text-to-image", "dreambooth"]
if isinstance(pipeline, StableDiffusionPipeline):
tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
else:
tags.extend(["if", "if-diffusers"])
model_card = populate_model_card(model_card, tags=tags)

model_card.save(os.path.join(repo_folder, "README.md"))


def log_validation(
Expand Down
36 changes: 19 additions & 17 deletions examples/dreambooth/train_dreambooth_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module

Expand All @@ -85,30 +86,31 @@ def save_model_card(
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"

yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
- text-to-image
- diffusers
- lora
inference: true
---
"""
model_card = f"""
model_description = f"""
# LoRA DreamBooth - {repo_id}

These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
{img_str}

LoRA for the text encoder was enabled: {train_text_encoder}.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
instance_prompt=prompt,
model_description=model_description,
inference=True,
)
tags = ["text-to-image", "diffusers", "lora"]
if isinstance(pipeline, StableDiffusionPipeline):
tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
else:
tags.extend(["if", "if-diffusers"])
model_card = populate_model_card(model_card, tags=tags)

model_card.save(os.path.join(repo_folder, "README.md"))


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
Expand Down
64 changes: 33 additions & 31 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module

Expand All @@ -75,40 +76,22 @@
def save_model_card(
repo_id: str,
images=None,
base_model=str,
base_model: str = None,
train_text_encoder=False,
instance_prompt=str,
validation_prompt=str,
instance_prompt=None,
validation_prompt=None,
Comment on lines -78 to +82
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was having compulsion disorder :3 So, decided to fix these.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would even update to this 😄

    base_model: Optional[str] = None,
    train_text_encoder: bool = False,
    instance_prompt: Optional[str] = None,
    validation_prompt: Optional[str] = None,

repo_folder=None,
vae_path=None,
):
img_str = "widget:\n" if images else ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"""
- text: '{validation_prompt if validation_prompt else ' ' }'
output:
url:
"image_{i}.png"
"""

yaml = f"""
---
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
- lora
- template:sd-lora
{img_str}
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
---
"""
widget_dict = []
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
widget_dict.append(
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
)

model_card = f"""
model_description = f"""
# SDXL LoRA DreamBooth - {repo_id}

<Gallery />
Expand All @@ -134,8 +117,27 @@ def save_model_card(
[Download]({repo_id}/tree/main) them in the Files & versions tab.

"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="openrail++",
base_model=base_model,
instance_prompt=instance_prompt,
model_description=model_description,
widget=widget_dict,
)
tags = [
"text-to-image",
"stable-diffusion-xl",
"stable-diffusion-xl-diffusers",
"text-to-image",
"diffusers",
"lora",
"template:sd-lora",
]
model_card = populate_model_card(model_card, tags=tags)

model_card.save(os.path.join(repo_folder, "README.md"))


def import_model_class_from_model_name_or_path(
Expand Down
74 changes: 59 additions & 15 deletions src/diffusers/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import traceback
import warnings
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Union
from uuid import uuid4

from huggingface_hub import (
Expand Down Expand Up @@ -65,7 +65,7 @@

logger = get_logger(__name__)


MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
SESSION_ID = uuid4().hex


Expand Down Expand Up @@ -94,43 +94,87 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:


def load_or_create_model_card(
repo_id_or_path: Optional[str] = None, token: Optional[str] = None, is_pipeline: bool = False
repo_id_or_path: str = None,
token: Optional[str] = None,
is_pipeline: bool = False,
from_training: bool = False,
model_description: Optional[str] = None,
base_model: str = None,
prompt: Optional[str] = None,
license: Optional[str] = None,
widget: Optional[List[dict]] = None,
inference: Optional[bool] = None,
) -> ModelCard:
"""
Loads or creates a model card.

Args:
repo_id (`str`):
The repo_id where to look for the model card.
repo_id_or_path (`str`):
The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card.
token (`str`, *optional*):
Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details.
is_pipeline (`bool`, *optional*):
is_pipeline (`bool`):
Boolean to indicate if we're adding tag to a [`DiffusionPipeline`].
from_training: (`bool`): Boolean flag to denote if the model card is being created from a training script.
model_description (`str`, *optional*): Model description to add to the model card. Helpful when using
`load_or_create_model_card` from a training script.
base_model (`str`): Base model identifier (e.g., "stabilityai/stable-diffusion-xl-base-1.0"). Useful
for DreamBooth-like training.
prompt (`str`, *optional*): Prompt used for training. Useful for DreamBooth-like training.
license: (`str`, *optional*): License of the output artifact. Helpful when using
`load_or_create_model_card` from a training script.
widget (`List[dict]`, *optional*): Widget to accompany a gallery template.
inference: (`bool`, optional): Whether to turn on inference widget. Helpful when using
`load_or_create_model_card` from a training script.
"""
if not is_jinja_available():
raise ValueError(
"Modelcard rendering is based on Jinja templates."
" Please make sure to have `jinja` installed before using `create_model_card`."
" Please make sure to have `jinja` installed before using `load_or_create_model_card`."
" To install it, please run `pip install Jinja2`."
)

try:
# Check if the model card is present on the remote repo
model_card = ModelCard.load(repo_id_or_path, token=token)
except EntryNotFoundError:
# Otherwise create a simple model card from template
component = "pipeline" if is_pipeline else "model"
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
card_data = ModelCardData()
model_card = ModelCard.from_template(card_data, model_description=model_description)
except (EntryNotFoundError, RepositoryNotFoundError):
# Otherwise create a model card from template
if from_training:
model_card = ModelCard.from_template(
card_data=ModelCardData( # Card metadata object that will be converted to YAML block
license=license,
library_name="diffusers",
inference=inference,
base_model=base_model,
instance_prompt=prompt,
widget=widget,
),
template_path=MODEL_CARD_TEMPLATE_PATH,
model_description=model_description,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If widget is populated, I would automatically add a </Gallery> tag to the model card description like done in https://huggingface.co/Pclanglais/Mickey-1928/blob/main/README.md?code=true. It will nicely render example images in a gallery (see here) and you wouldn't have to manually build a img_str in your notebook example. What do you think? :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(maybe only if widget is not none and "<gallery>/" not in model_description)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That means the existing save_model_card will have to be rewritten a bit to support compatibility with </Gallery>. I propose not to do that in this PR as it differs in scope.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine for me!

)
else:
card_data = ModelCardData()
component = "pipeline" if is_pipeline else "model"
if model_description is None:
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
model_card = ModelCard.from_template(card_data, model_description=model_description)

return model_card


def populate_model_card(model_card: ModelCard) -> ModelCard:
"""Populates the `model_card` with library name."""
def populate_model_card(model_card: ModelCard, tags: Union[str, List[str]] = None) -> ModelCard:
"""Populates the `model_card` with library name and optional tags."""
if model_card.data.library_name is None:
model_card.data.library_name = "diffusers"

if tags is not None:
if isinstance(tags, str):
tags = [tags]
if model_card.data.tags is None:
model_card.data.tags = []
for tag in tags:
model_card.data.tags.append(tag)

return model_card


Expand Down
Loading