Skip to content

Commit 4e3e587

Browse files
committed
change str to dict
1 parent 4420806 commit 4e3e587

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,20 +90,17 @@ def save_model_card(
9090
repo_folder=None,
9191
vae_path=None,
9292
):
93-
img_str = "widget:\n"
94-
lora = "lora" if not use_dora else "dora"
95-
for i, image in enumerate(images):
96-
image.save(os.path.join(repo_folder, f"image_{i}.png"))
97-
img_str += f"""
98-
- text: '{validation_prompt if validation_prompt else ' ' }'
99-
output:
100-
url:
101-
"image_{i}.png"
102-
"""
103-
if not images:
104-
img_str += f"""
105-
- text: '{instance_prompt}'
106-
"""
93+
widget_dict = []
94+
if images is not None:
95+
for i, image in enumerate(images):
96+
image.save(os.path.join(repo_folder, f"image_{i}.png"))
97+
widget_dict.append(
98+
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
99+
)
100+
else:
101+
widget_dict.append(
102+
{"text": instance_prompt}
103+
)
107104
embeddings_filename = f"{repo_folder}_emb"
108105
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
109106
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
@@ -194,6 +191,7 @@ def save_model_card(
194191
prompt=instance_prompt,
195192
model_description=model_description,
196193
inference=True,
194+
widget=widget_dict,
197195
)
198196

199197
tags = ["text-to-image",

0 commit comments

Comments
 (0)