diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 5eba1952e608..0a7a222ec007 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -198,10 +198,31 @@ def convert_to_variant(filename): variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" return variant_filename - for f in non_variant_filenames: - variant_filename = convert_to_variant(f) - if variant_filename not in usable_filenames: - usable_filenames.add(f) + def find_component(filename): + if not len(filename.split("/")) == 2: + return + component = filename.split("/")[0] + return component + + def has_sharded_variant(component, variant, variant_filenames): + # If component exists check for sharded variant index filename + # If component doesn't exist check main dir for sharded variant index filename + component = component + "/" if component else "" + variant_index_re = re.compile( + rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" + ) + return any(f for f in variant_filenames if variant_index_re.match(f) is not None) + + for filename in non_variant_filenames: + if convert_to_variant(filename) in variant_filenames: + continue + + component = find_component(filename) + # If a sharded variant exists skip adding to allowed patterns + if has_sharded_variant(component, variant, variant_filenames): + continue + + usable_filenames.add(filename) return usable_filenames, variant_filenames diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index bb3bdc273cc4..acf7d9d8401b 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -18,7 +18,7 @@ StableDiffusionPipeline, UNet2DConditionModel, ) -from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible +from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings from diffusers.utils.testing_utils import torch_device @@ -210,6 +210,135 @@ def test_diffusers_is_compatible_no_components_only_variants(self): self.assertFalse(is_safetensors_compatible(filenames)) +class VariantCompatibleSiblingsTest(unittest.TestCase): + def test_only_non_variants_downloaded(self): + variant = "fp16" + filenames = [ + f"vae/diffusion_pytorch_model.{variant}.safetensors", + "vae/diffusion_pytorch_model.safetensors", + f"text_encoder/model.{variant}.safetensors", + "text_encoder/model.safetensors", + f"unet/diffusion_pytorch_model.{variant}.safetensors", + "unet/diffusion_pytorch_model.safetensors", + ] + + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + assert all(variant not in f for f in model_filenames) + + def test_only_variants_downloaded(self): + variant = "fp16" + filenames = [ + f"vae/diffusion_pytorch_model.{variant}.safetensors", + "vae/diffusion_pytorch_model.safetensors", + f"text_encoder/model.{variant}.safetensors", + "text_encoder/model.safetensors", + f"unet/diffusion_pytorch_model.{variant}.safetensors", + "unet/diffusion_pytorch_model.safetensors", + ] + + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f for f in model_filenames) + + def test_mixed_variants_downloaded(self): + variant = "fp16" + non_variant_file = "text_encoder/model.safetensors" + filenames = [ + f"vae/diffusion_pytorch_model.{variant}.safetensors", + "vae/diffusion_pytorch_model.safetensors", + "text_encoder/model.safetensors", + f"unet/diffusion_pytorch_model.{variant}.safetensors", + "unet/diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) + + def test_non_variants_in_main_dir_downloaded(self): + variant = "fp16" + filenames = [ + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + "model.safetensors", + f"model.{variant}.safetensors", + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + assert all(variant not in f for f in model_filenames) + + def test_variants_in_main_dir_downloaded(self): + variant = "fp16" + filenames = [ + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + "model.safetensors", + f"model.{variant}.safetensors", + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f for f in model_filenames) + + def test_mixed_variants_in_main_dir_downloaded(self): + variant = "fp16" + non_variant_file = "model.safetensors" + filenames = [ + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + "model.safetensors", + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) + + def test_sharded_non_variants_downloaded(self): + variant = "fp16" + filenames = [ + f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", + f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + assert all(variant not in f for f in model_filenames) + + def test_sharded_variants_downloaded(self): + variant = "fp16" + filenames = [ + f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", + f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f for f in model_filenames) + + def test_sharded_mixed_variants_downloaded(self): + variant = "fp16" + allowed_non_variant = "unet" + filenames = [ + f"vae/diffusion_pytorch_model.safetensors.index.{variant}.json", + "vae/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", + f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", + "vae/diffusion_pytorch_model-00001-of-00003.safetensors", + "vae/diffusion_pytorch_model-00002-of-00003.safetensors", + "vae/diffusion_pytorch_model-00003-of-00003.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + + class ProgressBarTests(unittest.TestCase): def get_dummy_components_image_generation(self): cross_attention_dim = 8