From bfc66f8aa0deadc1fd22c0766f50e51f0bc3949b Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 26 May 2025 14:46:36 +0530 Subject: [PATCH 1/2] update --- src/diffusers/pipelines/pipeline_loading_utils.py | 12 +++++++++--- tests/pipelines/test_pipeline_utils.py | 7 +++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 4079fd14804b..7132e9521f79 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -146,21 +146,27 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No components[component].append(component_filename) # If there are no component folders check the main directory for safetensors files + filtered_filenames = set() if not components: if variant is not None: filtered_filenames = filter_with_regex(filenames, variant_file_re) - else: + + # If no variant filenames exist check if non-variant files are available + if not filtered_filenames: filtered_filenames = filter_with_regex(filenames, non_variant_file_re) return any(".safetensors" in filename for filename in filtered_filenames) # iterate over all files of a component # check if safetensor files exist for that component - # if variant is provided check if the variant of the safetensors exists for component, component_filenames in components.items(): matches = [] + filtered_component_filenames = set() + # if variant is provided check if the variant of the safetensors exists if variant is not None: filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re) - else: + + # if variant safetensor files do not exist check for non-variants + if not filtered_component_filenames: filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re) for component_filename in filtered_component_filenames: filename, extension = os.path.splitext(component_filename) diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index f680cf2dcf18..610ec9a06fe5 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -217,6 +217,13 @@ def test_diffusers_is_compatible_no_components_only_variants(self): ] self.assertFalse(is_safetensors_compatible(filenames)) + def test_is_compatible_mixed_variants(self): + filenames = [ + "unet/diffusion_pytorch_model.fp16.safetensors", + "vae/diffusion_pytorch_model.safetensors", + ] + self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) + class VariantCompatibleSiblingsTest(unittest.TestCase): def test_only_non_variants_downloaded(self): From 679e0958a23d3b363fe114b1339ffdf1cc2fb3dd Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 26 May 2025 15:02:53 +0530 Subject: [PATCH 2/2] update --- tests/pipelines/test_pipeline_utils.py | 7 ++++ tests/pipelines/test_pipelines.py | 50 ++++++++++---------------- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 610ec9a06fe5..5154155447b5 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -224,6 +224,13 @@ def test_is_compatible_mixed_variants(self): ] self.assertTrue(is_safetensors_compatible(filenames, variant="fp16")) + def test_is_compatible_variant_and_non_safetensors(self): + filenames = [ + "unet/diffusion_pytorch_model.fp16.safetensors", + "vae/diffusion_pytorch_model.bin", + ] + self.assertFalse(is_safetensors_compatible(filenames, variant="fp16")) + class VariantCompatibleSiblingsTest(unittest.TestCase): def test_only_non_variants_downloaded(self): diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index a2241236da20..ef35ea2678db 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -538,38 +538,26 @@ def test_download_variant_partly(self): variant = "no_ema" with tempfile.TemporaryDirectory() as tmpdirname: - if use_safetensors: - with self.assertRaises(OSError) as error_context: - tmpdirname = StableDiffusionPipeline.download( - "hf-internal-testing/stable-diffusion-all-variants", - cache_dir=tmpdirname, - variant=variant, - use_safetensors=use_safetensors, - ) - assert "Could not find the necessary `safetensors` weights" in str(error_context.exception) - else: - tmpdirname = StableDiffusionPipeline.download( - "hf-internal-testing/stable-diffusion-all-variants", - cache_dir=tmpdirname, - variant=variant, - use_safetensors=use_safetensors, - ) - all_root_files = [t[-1] for t in os.walk(tmpdirname)] - files = [item for sublist in all_root_files for item in sublist] + tmpdirname = StableDiffusionPipeline.download( + "hf-internal-testing/stable-diffusion-all-variants", + cache_dir=tmpdirname, + variant=variant, + use_safetensors=use_safetensors, + ) + all_root_files = [t[-1] for t in os.walk(tmpdirname)] + files = [item for sublist in all_root_files for item in sublist] - unet_files = os.listdir(os.path.join(tmpdirname, "unet")) - - # Some of the downloaded files should be a non-variant file, check: - # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet - assert len(files) == 15, f"We should only download 15 files, not {len(files)}" - # only unet has "no_ema" variant - assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files - assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1 - # vae, safety_checker and text_encoder should have no variant - assert ( - sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 - ) - assert not any(f.endswith(other_format) for f in files) + unet_files = os.listdir(os.path.join(tmpdirname, "unet")) + + # Some of the downloaded files should be a non-variant file, check: + # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet + assert len(files) == 15, f"We should only download 15 files, not {len(files)}" + # only unet has "no_ema" variant + assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files + assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1 + # vae, safety_checker and text_encoder should have no variant + assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 + assert not any(f.endswith(other_format) for f in files) def test_download_variants_with_sharded_checkpoints(self): # Here we test for downloading of "variant" files belonging to the `unet` and