Skip to content

Commit 76696dc

Browse files
authored
[Model Card] standardize dreambooth model card (#6729)
* feat: standardize model card creation for dreambooth training. * correct 'inference' * remove comments. * take component out of kwargs * style * add: card template to have a leaner description. * widget support. * propagate changes to train_dreambooth_lora * propagate changes to custom diffusion * make widget properly type-annotated
1 parent 17612de commit 76696dc

File tree

6 files changed

+180
-105
lines changed

6 files changed

+180
-105
lines changed

examples/custom_diffusion/train_custom_diffusion.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
)
5959
from diffusers.optimization import get_scheduler
6060
from diffusers.utils import check_min_version, is_wandb_available
61+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
6162
from diffusers.utils.import_utils import is_xformers_available
6263

6364

@@ -78,30 +79,28 @@ def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_
7879
image.save(os.path.join(repo_folder, f"image_{i}.png"))
7980
img_str += f"![img_{i}](./image_{i}.png)\n"
8081

81-
yaml = f"""
82-
---
83-
license: creativeml-openrail-m
84-
base_model: {base_model}
85-
instance_prompt: {prompt}
86-
tags:
87-
- stable-diffusion
88-
- stable-diffusion-diffusers
89-
- text-to-image
90-
- diffusers
91-
- custom-diffusion
92-
inference: true
93-
---
94-
"""
95-
model_card = f"""
82+
model_description = f"""
9683
# Custom Diffusion - {repo_id}
9784
9885
These are Custom Diffusion adaption weights for {base_model}. The weights were trained on {prompt} using [Custom Diffusion](https://www.cs.cmu.edu/~custom-diffusion). You can find some example images in the following. \n
9986
{img_str}
10087
10188
\nFor more details on the training, please follow [this link](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion).
10289
"""
103-
with open(os.path.join(repo_folder, "README.md"), "w") as f:
104-
f.write(yaml + model_card)
90+
model_card = load_or_create_model_card(
91+
repo_id_or_path=repo_id,
92+
from_training=True,
93+
license="creativeml-openrail-m",
94+
base_model=base_model,
95+
instance_prompt=prompt,
96+
model_description=model_description,
97+
inference=True,
98+
)
99+
100+
tags = ["text-to-image", "diffusers", "stable-diffusion", "stable-diffusion-diffusers", "custom-diffusion"]
101+
model_card = populate_model_card(model_card, tags=tags)
102+
103+
model_card.save(os.path.join(repo_folder, "README.md"))
105104

106105

107106
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):

examples/dreambooth/train_dreambooth.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from diffusers.optimization import get_scheduler
5555
from diffusers.training_utils import compute_snr
5656
from diffusers.utils import check_min_version, is_wandb_available
57+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
5758
from diffusers.utils.import_utils import is_xformers_available
5859
from diffusers.utils.torch_utils import is_compiled_module
5960

@@ -69,33 +70,20 @@
6970

7071
def save_model_card(
7172
repo_id: str,
72-
images=None,
73-
base_model=str,
73+
images: list = None,
74+
base_model: str = None,
7475
train_text_encoder=False,
75-
prompt=str,
76-
repo_folder=None,
76+
prompt: str = None,
77+
repo_folder: str = None,
7778
pipeline: DiffusionPipeline = None,
7879
):
7980
img_str = ""
80-
for i, image in enumerate(images):
81-
image.save(os.path.join(repo_folder, f"image_{i}.png"))
82-
img_str += f"![img_{i}](./image_{i}.png)\n"
83-
84-
yaml = f"""
85-
---
86-
license: creativeml-openrail-m
87-
base_model: {base_model}
88-
instance_prompt: {prompt}
89-
tags:
90-
- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
91-
- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
92-
- text-to-image
93-
- diffusers
94-
- dreambooth
95-
inference: true
96-
---
97-
"""
98-
model_card = f"""
81+
if images is not None:
82+
for i, image in enumerate(images):
83+
image.save(os.path.join(repo_folder, f"image_{i}.png"))
84+
img_str += f"![img_{i}](./image_{i}.png)\n"
85+
86+
model_description = f"""
9987
# DreamBooth - {repo_id}
10088
10189
This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
@@ -104,8 +92,24 @@ def save_model_card(
10492
10593
DreamBooth for the text encoder was enabled: {train_text_encoder}.
10694
"""
107-
with open(os.path.join(repo_folder, "README.md"), "w") as f:
108-
f.write(yaml + model_card)
95+
model_card = load_or_create_model_card(
96+
repo_id_or_path=repo_id,
97+
from_training=True,
98+
license="creativeml-openrail-m",
99+
base_model=base_model,
100+
instance_prompt=prompt,
101+
model_description=model_description,
102+
inference=True,
103+
)
104+
105+
tags = ["text-to-image", "dreambooth"]
106+
if isinstance(pipeline, StableDiffusionPipeline):
107+
tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
108+
else:
109+
tags.extend(["if", "if-diffusers"])
110+
model_card = populate_model_card(model_card, tags=tags)
111+
112+
model_card.save(os.path.join(repo_folder, "README.md"))
109113

110114

111115
def log_validation(

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
convert_unet_state_dict_to_peft,
6262
is_wandb_available,
6363
)
64+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
6465
from diffusers.utils.import_utils import is_xformers_available
6566
from diffusers.utils.torch_utils import is_compiled_module
6667

@@ -85,30 +86,31 @@ def save_model_card(
8586
image.save(os.path.join(repo_folder, f"image_{i}.png"))
8687
img_str += f"![img_{i}](./image_{i}.png)\n"
8788

88-
yaml = f"""
89-
---
90-
license: creativeml-openrail-m
91-
base_model: {base_model}
92-
instance_prompt: {prompt}
93-
tags:
94-
- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
95-
- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
96-
- text-to-image
97-
- diffusers
98-
- lora
99-
inference: true
100-
---
101-
"""
102-
model_card = f"""
89+
model_description = f"""
10390
# LoRA DreamBooth - {repo_id}
10491
10592
These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
10693
{img_str}
10794
10895
LoRA for the text encoder was enabled: {train_text_encoder}.
10996
"""
110-
with open(os.path.join(repo_folder, "README.md"), "w") as f:
111-
f.write(yaml + model_card)
97+
model_card = load_or_create_model_card(
98+
repo_id_or_path=repo_id,
99+
from_training=True,
100+
license="creativeml-openrail-m",
101+
base_model=base_model,
102+
instance_prompt=prompt,
103+
model_description=model_description,
104+
inference=True,
105+
)
106+
tags = ["text-to-image", "diffusers", "lora"]
107+
if isinstance(pipeline, StableDiffusionPipeline):
108+
tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
109+
else:
110+
tags.extend(["if", "if-diffusers"])
111+
model_card = populate_model_card(model_card, tags=tags)
112+
113+
model_card.save(os.path.join(repo_folder, "README.md"))
112114

113115

114116
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
convert_unet_state_dict_to_peft,
6363
is_wandb_available,
6464
)
65+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
6566
from diffusers.utils.import_utils import is_xformers_available
6667
from diffusers.utils.torch_utils import is_compiled_module
6768

@@ -75,40 +76,22 @@
7576
def save_model_card(
7677
repo_id: str,
7778
images=None,
78-
base_model=str,
79+
base_model: str = None,
7980
train_text_encoder=False,
80-
instance_prompt=str,
81-
validation_prompt=str,
81+
instance_prompt=None,
82+
validation_prompt=None,
8283
repo_folder=None,
8384
vae_path=None,
8485
):
85-
img_str = "widget:\n" if images else ""
86-
for i, image in enumerate(images):
87-
image.save(os.path.join(repo_folder, f"image_{i}.png"))
88-
img_str += f"""
89-
- text: '{validation_prompt if validation_prompt else ' ' }'
90-
output:
91-
url:
92-
"image_{i}.png"
93-
"""
94-
95-
yaml = f"""
96-
---
97-
tags:
98-
- stable-diffusion-xl
99-
- stable-diffusion-xl-diffusers
100-
- text-to-image
101-
- diffusers
102-
- lora
103-
- template:sd-lora
104-
{img_str}
105-
base_model: {base_model}
106-
instance_prompt: {instance_prompt}
107-
license: openrail++
108-
---
109-
"""
86+
widget_dict = []
87+
if images is not None:
88+
for i, image in enumerate(images):
89+
image.save(os.path.join(repo_folder, f"image_{i}.png"))
90+
widget_dict.append(
91+
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
92+
)
11093

111-
model_card = f"""
94+
model_description = f"""
11295
# SDXL LoRA DreamBooth - {repo_id}
11396
11497
<Gallery />
@@ -134,8 +117,27 @@ def save_model_card(
134117
[Download]({repo_id}/tree/main) them in the Files & versions tab.
135118
136119
"""
137-
with open(os.path.join(repo_folder, "README.md"), "w") as f:
138-
f.write(yaml + model_card)
120+
model_card = load_or_create_model_card(
121+
repo_id_or_path=repo_id,
122+
from_training=True,
123+
license="openrail++",
124+
base_model=base_model,
125+
instance_prompt=instance_prompt,
126+
model_description=model_description,
127+
widget=widget_dict,
128+
)
129+
tags = [
130+
"text-to-image",
131+
"stable-diffusion-xl",
132+
"stable-diffusion-xl-diffusers",
133+
"text-to-image",
134+
"diffusers",
135+
"lora",
136+
"template:sd-lora",
137+
]
138+
model_card = populate_model_card(model_card, tags=tags)
139+
140+
model_card.save(os.path.join(repo_folder, "README.md"))
139141

140142

141143
def import_model_class_from_model_name_or_path(

src/diffusers/utils/hub_utils.py

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import traceback
2222
import warnings
2323
from pathlib import Path
24-
from typing import Dict, Optional, Union
24+
from typing import Dict, List, Optional, Union
2525
from uuid import uuid4
2626

2727
from huggingface_hub import (
@@ -65,7 +65,7 @@
6565

6666
logger = get_logger(__name__)
6767

68-
68+
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
6969
SESSION_ID = uuid4().hex
7070

7171

@@ -94,43 +94,87 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
9494

9595

9696
def load_or_create_model_card(
97-
repo_id_or_path: Optional[str] = None, token: Optional[str] = None, is_pipeline: bool = False
97+
repo_id_or_path: str = None,
98+
token: Optional[str] = None,
99+
is_pipeline: bool = False,
100+
from_training: bool = False,
101+
model_description: Optional[str] = None,
102+
base_model: str = None,
103+
prompt: Optional[str] = None,
104+
license: Optional[str] = None,
105+
widget: Optional[List[dict]] = None,
106+
inference: Optional[bool] = None,
98107
) -> ModelCard:
99108
"""
100109
Loads or creates a model card.
101110
102111
Args:
103-
repo_id (`str`):
104-
The repo_id where to look for the model card.
112+
repo_id_or_path (`str`):
113+
The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card.
105114
token (`str`, *optional*):
106115
Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details.
107-
is_pipeline (`bool`, *optional*):
116+
is_pipeline (`bool`):
108117
Boolean to indicate if we're adding tag to a [`DiffusionPipeline`].
118+
from_training (`bool`): Boolean flag to denote if the model card is being created from a training script.
119+
model_description (`str`, *optional*): Model description to add to the model card. Helpful when using
120+
`load_or_create_model_card` from a training script.
121+
base_model (`str`): Base model identifier (e.g., "stabilityai/stable-diffusion-xl-base-1.0"). Useful
122+
for DreamBooth-like training.
123+
prompt (`str`, *optional*): Prompt used for training. Useful for DreamBooth-like training.
124+
license (`str`, *optional*): License of the output artifact. Helpful when using
125+
`load_or_create_model_card` from a training script.
126+
widget (`List[dict]`, *optional*): Widget to accompany a gallery template.
127+
inference (`bool`, *optional*): Whether to turn on the inference widget. Helpful when using
128+
`load_or_create_model_card` from a training script.
109129
"""
110130
if not is_jinja_available():
111131
raise ValueError(
112132
"Modelcard rendering is based on Jinja templates."
113-
" Please make sure to have `jinja` installed before using `create_model_card`."
133+
" Please make sure to have `jinja` installed before using `load_or_create_model_card`."
114134
" To install it, please run `pip install Jinja2`."
115135
)
116136

117137
try:
118138
# Check if the model card is present on the remote repo
119139
model_card = ModelCard.load(repo_id_or_path, token=token)
120-
except EntryNotFoundError:
121-
# Otherwise create a simple model card from template
122-
component = "pipeline" if is_pipeline else "model"
123-
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
124-
card_data = ModelCardData()
125-
model_card = ModelCard.from_template(card_data, model_description=model_description)
140+
except (EntryNotFoundError, RepositoryNotFoundError):
141+
# Otherwise create a model card from template
142+
if from_training:
143+
model_card = ModelCard.from_template(
144+
card_data=ModelCardData( # Card metadata object that will be converted to YAML block
145+
license=license,
146+
library_name="diffusers",
147+
inference=inference,
148+
base_model=base_model,
149+
instance_prompt=prompt,
150+
widget=widget,
151+
),
152+
template_path=MODEL_CARD_TEMPLATE_PATH,
153+
model_description=model_description,
154+
)
155+
else:
156+
card_data = ModelCardData()
157+
component = "pipeline" if is_pipeline else "model"
158+
if model_description is None:
159+
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
160+
model_card = ModelCard.from_template(card_data, model_description=model_description)
126161

127162
return model_card
128163

129164

130-
def populate_model_card(model_card: ModelCard) -> ModelCard:
131-
"""Populates the `model_card` with library name."""
165+
def populate_model_card(model_card: ModelCard, tags: Union[str, List[str]] = None) -> ModelCard:
166+
"""Populates the `model_card` with library name and optional tags."""
132167
if model_card.data.library_name is None:
133168
model_card.data.library_name = "diffusers"
169+
170+
if tags is not None:
171+
if isinstance(tags, str):
172+
tags = [tags]
173+
if model_card.data.tags is None:
174+
model_card.data.tags = []
175+
for tag in tags:
176+
model_card.data.tags.append(tag)
177+
134178
return model_card
135179

136180

0 commit comments

Comments
 (0)