diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 38b6e8dab209..f8253715e64d 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -71,6 +71,7 @@ convert_unet_state_dict_to_peft, is_wandb_available, ) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module @@ -101,7 +102,7 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision): def save_model_card( repo_id: str, use_dora: bool, - images=None, + images: list = None, base_model: str = None, train_text_encoder=False, train_text_encoder_ti=False, @@ -111,20 +112,17 @@ def save_model_card( repo_folder=None, vae_path=None, ): - img_str = "widget:\n" lora = "lora" if not use_dora else "dora" - for i, image in enumerate(images): - image.save(os.path.join(repo_folder, f"image_{i}.png")) - img_str += f""" - - text: '{validation_prompt if validation_prompt else ' ' }' - output: - url: - "image_{i}.png" - """ - if not images: - img_str += f""" - - text: '{instance_prompt}' - """ + + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + else: + widget_dict.append({"text": instance_prompt}) embeddings_filename = f"{repo_folder}_emb" instance_prompt_webui = re.sub(r"", "", re.sub(r"", embeddings_filename, instance_prompt, count=1)) ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"", instance_prompt)) @@ -169,23 +167,7 @@ def save_model_card( to trigger concept `{key}` → use `{tokens}` in your prompt \n """ - yaml = f"""--- -tags: -- stable-diffusion-xl -- stable-diffusion-xl-diffusers -- diffusers-training -- text-to-image -- diffusers -- {lora} -- template:sd-lora -{img_str} -base_model: {base_model} -instance_prompt: {instance_prompt} -license: openrail++ ---- -""" - - model_card = f""" + model_description = f""" # SDXL LoRA DreamBooth - {repo_id} @@ -234,8 +216,25 @@ def save_model_card( {license} """ - with open(os.path.join(repo_folder, "README.md"), "w") as f: - f.write(yaml + model_card) + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="openrail++", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "stable-diffusion-xl", + "stable-diffusion-xl-diffusers", + "text-to-image", + "diffusers", + lora, + "template:sd-lora", + ] + model_card = populate_model_card(model_card, tags=tags) def log_validation(