diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py index 9cf23dbf4a32..5ebb2f3b0dea 100644 --- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py +++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py @@ -49,6 +49,7 @@ ) from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module @@ -195,7 +196,7 @@ def import_model_class_from_model_name_or_path( raise ValueError(f"{model_class} is not supported.") -def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): +def save_model_card(repo_id: str, image_logs: dict = None, base_model: str = None, repo_folder: str = None): img_str = "" if image_logs is not None: img_str = "You can find some example images below.\n" @@ -209,27 +210,25 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) img_str += f"![images_{i})](./images_{i}.png)\n" - yaml = f""" ---- -license: creativeml-openrail-m -base_model: {base_model} -tags: -- stable-diffusion-xl -- stable-diffusion-xl-diffusers -- text-to-image -- diffusers -- t2iadapter -inference: true ---- - """ - model_card = f""" + model_description = f""" # t2iadapter-{repo_id} These are t2iadapter weights trained on {base_model} with new type of conditioning. {img_str} """ - with open(os.path.join(repo_folder, "README.md"), "w") as f: - f.write(yaml + model_card) + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="creativeml-openrail-m", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = ["stable-diffusion-xl", "stable-diffusion-xl-diffusers", "text-to-image", "diffusers", "t2iadapter"] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) def parse_args(input_args=None):