From 7b668b1de122964774e178b056a30c43dc806fab Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 5 Nov 2024 18:54:03 +0100 Subject: [PATCH 1/4] update --- src/diffusers/pipelines/pipeline_loading_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 5eba1952e608..f183d55dc759 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -198,10 +198,12 @@ def convert_to_variant(filename): variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" return variant_filename + components_with_variant = {filename.split("/")[0] for filename in variant_filenames} for f in non_variant_filenames: - variant_filename = convert_to_variant(f) - if variant_filename not in usable_filenames: - usable_filenames.add(f) + component, component_filename = f.split("/") + if component in components_with_variant: + continue + usable_filenames.add(f) return usable_filenames, variant_filenames From ed0a6d70c5045e602f943577b2a2b0899f65ad4d Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 5 Nov 2024 20:50:29 +0100 Subject: [PATCH 2/4] update --- src/diffusers/pipelines/pipeline_loading_utils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index f183d55dc759..c070f8990dd1 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -198,12 +198,21 @@ def convert_to_variant(filename): variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" return variant_filename - components_with_variant = {filename.split("/")[0] for filename in variant_filenames} + components_with_variant = set() + for filename in variant_filenames: + if not len(filename.split("/")) == 2: + continue + component, component_filename = filename.split("/") + components_with_variant.add(component) + for f in non_variant_filenames: - component, component_filename = f.split("/") + component = f.split("/")[0] + # If a component already has a variant skip including any other files if component in components_with_variant: continue - usable_filenames.add(f) + # If a variant version of a file doesn't exist add the file to the allowed patterns list + if convert_to_variant(f) not in variant_filenames: + usable_filenames.add(f) return usable_filenames, variant_filenames From f3b76fb4301aa80426c386d246a2b0cb993a33e6 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 6 Nov 2024 19:47:15 +0100 Subject: [PATCH 3/4] update --- .../pipelines/pipeline_loading_utils.py | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index c070f8990dd1..0a7a222ec007 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -198,21 +198,31 @@ def convert_to_variant(filename): variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" return variant_filename - components_with_variant = set() - for filename in variant_filenames: + def find_component(filename): if not len(filename.split("/")) == 2: + return + component = filename.split("/")[0] + return component + + def has_sharded_variant(component, variant, variant_filenames): + # If component exists check for sharded variant index filename + # If component doesn't exist check main dir for sharded variant index filename + component = component + "/" if component else "" + variant_index_re = re.compile( + rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" + ) + return any(f for f in variant_filenames if variant_index_re.match(f) is not None) + + for filename in non_variant_filenames: + if convert_to_variant(filename) in variant_filenames: continue - component, component_filename = filename.split("/") - components_with_variant.add(component) - for f in non_variant_filenames: - component = f.split("/")[0] - # If a component already has a variant skip including any other files - if component in components_with_variant: + component = find_component(filename) + # If a sharded variant exists skip adding to allowed patterns + if has_sharded_variant(component, variant, variant_filenames): continue - # If a variant version of a file doesn't exist add the file to the allowed patterns list - if convert_to_variant(f) not in variant_filenames: - usable_filenames.add(f) + + usable_filenames.add(filename) return usable_filenames, variant_filenames From b0caf4169ff8855cad834a4349867d090658efc6 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 8 Nov 2024 12:36:29 +0100 Subject: [PATCH 4/4] update --- tests/pipelines/test_pipeline_utils.py | 131 ++++++++++++++++++++++++- 1 file changed, 130 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index bb3bdc273cc4..acf7d9d8401b 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -18,7 +18,7 @@ StableDiffusionPipeline, UNet2DConditionModel, ) -from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible +from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings from diffusers.utils.testing_utils import torch_device @@ -210,6 +210,135 @@ def test_diffusers_is_compatible_no_components_only_variants(self): self.assertFalse(is_safetensors_compatible(filenames)) +class VariantCompatibleSiblingsTest(unittest.TestCase): + def test_only_non_variants_downloaded(self): + variant = "fp16" + filenames = [ + f"vae/diffusion_pytorch_model.{variant}.safetensors", + "vae/diffusion_pytorch_model.safetensors", + f"text_encoder/model.{variant}.safetensors", + "text_encoder/model.safetensors", + f"unet/diffusion_pytorch_model.{variant}.safetensors", + "unet/diffusion_pytorch_model.safetensors", + ] + + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + assert all(variant not in f for f in model_filenames) + + def test_only_variants_downloaded(self): + variant = "fp16" + filenames = [ + f"vae/diffusion_pytorch_model.{variant}.safetensors", + "vae/diffusion_pytorch_model.safetensors", + f"text_encoder/model.{variant}.safetensors", + "text_encoder/model.safetensors", + f"unet/diffusion_pytorch_model.{variant}.safetensors", + "unet/diffusion_pytorch_model.safetensors", + ] + + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f for f in model_filenames) + + def test_mixed_variants_downloaded(self): + variant = "fp16" + non_variant_file = "text_encoder/model.safetensors" + filenames = [ + f"vae/diffusion_pytorch_model.{variant}.safetensors", + "vae/diffusion_pytorch_model.safetensors", + "text_encoder/model.safetensors", + f"unet/diffusion_pytorch_model.{variant}.safetensors", + "unet/diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) + + def test_non_variants_in_main_dir_downloaded(self): + variant = "fp16" + filenames = [ + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + "model.safetensors", + f"model.{variant}.safetensors", + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + assert all(variant not in f for f in model_filenames) + + def test_variants_in_main_dir_downloaded(self): + variant = "fp16" + filenames = [ + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + "model.safetensors", + f"model.{variant}.safetensors", + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f for f in model_filenames) + + def test_mixed_variants_in_main_dir_downloaded(self): + variant = "fp16" + non_variant_file = "model.safetensors" + filenames = [ + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + "model.safetensors", + f"diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) + + def test_sharded_non_variants_downloaded(self): + variant = "fp16" + filenames = [ + f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", + f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + assert all(variant not in f for f in model_filenames) + + def test_sharded_variants_downloaded(self): + variant = "fp16" + filenames = [ + f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", + f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f for f in model_filenames) + + def test_sharded_mixed_variants_downloaded(self): + variant = "fp16" + allowed_non_variant = "unet" + filenames = [ + f"vae/diffusion_pytorch_model.safetensors.index.{variant}.json", + "vae/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", + f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", + "vae/diffusion_pytorch_model-00001-of-00003.safetensors", + "vae/diffusion_pytorch_model-00002-of-00003.safetensors", + "vae/diffusion_pytorch_model-00003-of-00003.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + + class ProgressBarTests(unittest.TestCase): def get_dummy_components_image_generation(self): cross_attention_dim = 8