diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 4079fd14804b..7132e9521f79 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -146,21 +146,27 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No components[component].append(component_filename) # If there are no component folders check the main directory for safetensors files + filtered_filenames = set() if not components: if variant is not None: filtered_filenames = filter_with_regex(filenames, variant_file_re) - else: + + # If no variant filenames exist check if non-variant files are available + if not filtered_filenames: filtered_filenames = filter_with_regex(filenames, non_variant_file_re) return any(".safetensors" in filename for filename in filtered_filenames) # iterate over all files of a component # check if safetensor files exist for that component - # if variant is provided check if the variant of the safetensors exists for component, component_filenames in components.items(): matches = [] + filtered_component_filenames = set() + # if variant is provided check if the variant of the safetensors exists if variant is not None: filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re) - else: + + # if variant safetensor files do not exist check for non-variants + if not filtered_component_filenames: filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re) for component_filename in filtered_component_filenames: filename, extension = os.path.splitext(component_filename) diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index f680cf2dcf18..5154155447b5 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -217,6 +217,20 @@ def test_diffusers_is_compatible_no_components_only_variants(self): ] self.assertFalse(is_safetensors_compatible(filenames)) + def test_is_compatible_mixed_variants(self): + filenames = [ + "unet/diffusion_pytorch_model.fp16.safetensors", + "vae/diffusion_pytorch_model.safetensors", + ] + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) + + def test_is_compatible_variant_and_non_safetensors(self): + filenames = [ + "unet/diffusion_pytorch_model.fp16.safetensors", + "vae/diffusion_pytorch_model.bin", + ] + self.assertFalse(is_safetensors_compatible(filenames, variant="fp16")) + class VariantCompatibleSiblingsTest(unittest.TestCase): def test_only_non_variants_downloaded(self): diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index a2241236da20..ef35ea2678db 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -538,38 +538,26 @@ def test_download_variant_partly(self): variant = "no_ema" with tempfile.TemporaryDirectory() as tmpdirname: - if use_safetensors: - with self.assertRaises(OSError) as error_context: - tmpdirname = StableDiffusionPipeline.download( - "hf-internal-testing/stable-diffusion-all-variants", - cache_dir=tmpdirname, - variant=variant, - use_safetensors=use_safetensors, - ) - assert "Could not find the necessary `safetensors` weights" in str(error_context.exception) - else: - tmpdirname = StableDiffusionPipeline.download( - "hf-internal-testing/stable-diffusion-all-variants", - cache_dir=tmpdirname, - variant=variant, - use_safetensors=use_safetensors, - ) - all_root_files = [t[-1] for t in os.walk(tmpdirname)] - files = [item for sublist in all_root_files for item in sublist] + tmpdirname = StableDiffusionPipeline.download( + "hf-internal-testing/stable-diffusion-all-variants", + cache_dir=tmpdirname, + variant=variant, + use_safetensors=use_safetensors, + ) + all_root_files = [t[-1] for t in os.walk(tmpdirname)] + files = [item for sublist in all_root_files for item in sublist] - unet_files = os.listdir(os.path.join(tmpdirname, "unet")) - - # Some of the downloaded files should be a non-variant file, check: - # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet - assert len(files) == 15, f"We should only download 15 files, not {len(files)}" - # only unet has "no_ema" variant - assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files - assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1 - # vae, safety_checker and text_encoder should have no variant - assert ( - sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 - ) - assert not any(f.endswith(other_format) for f in files) + unet_files = os.listdir(os.path.join(tmpdirname, "unet")) + + # Some of the downloaded files should be a non-variant file, check: + # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet + assert len(files) == 15, f"We should only download 15 files, not {len(files)}" + # only unet has "no_ema" variant + assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files + assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1 + # vae, safety_checker and text_encoder should have no variant + assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 + assert not any(f.endswith(other_format) for f in files) def test_download_variants_with_sharded_checkpoints(self): # Here we test for downloading of "variant" files belonging to the `unet` and