From 403417e926e988b04e040eebceeaf467c3763f59 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 24 Jan 2025 10:31:26 +0100 Subject: [PATCH 01/16] update --- .../pipelines/pipeline_loading_utils.py | 24 +++---- tests/pipelines/test_pipeline_utils.py | 71 +++++++++++++++++-- 2 files changed, 78 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 4173c49524dd..248f1eba5e2f 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -191,15 +191,6 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi # all variant filenames will be used by default usable_filenames = set(variant_filenames) - def convert_to_variant(filename): - if "index" in filename: - variant_filename = filename.replace("index", f"index.{variant}") - elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None: - variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}" - else: - variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" - return variant_filename - def find_component(filename): if not len(filename.split("/")) == 2: return @@ -215,15 +206,22 @@ def has_sharded_variant(component, variant, variant_filenames): ) return any(f for f in variant_filenames if variant_index_re.match(f) is not None) - for filename in non_variant_filenames: - if convert_to_variant(filename) in variant_filenames: - continue + def has_variant(component, variant_filenames): + component = component + "/" if component else "" + # Check for any variant file in this component + return any(f.startswith(component) for f in variant_filenames) + for filename in non_variant_filenames: component = find_component(filename) - # If a sharded variant exists skip adding to allowed patterns + + # Determine if sharded variant exists based on index file if has_sharded_variant(component, variant, variant_filenames): continue + # If a variant exists skip adding to allowed patterns + if has_variant(component, variant_filenames): + continue + usable_filenames.add(filename) return usable_filenames, variant_filenames diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index acf7d9d8401b..d8abe3e156d0 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -259,8 +259,6 @@ def test_non_variants_in_main_dir_downloaded(self): "diffusion_pytorch_model.safetensors", "model.safetensors", f"model.{variant}.safetensors", - f"diffusion_pytorch_model.{variant}.safetensors", - "diffusion_pytorch_model.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) assert all(variant not in f for f in model_filenames) @@ -285,12 +283,36 @@ def test_mixed_variants_in_main_dir_downloaded(self): f"diffusion_pytorch_model.{variant}.safetensors", "diffusion_pytorch_model.safetensors", "model.safetensors", - f"diffusion_pytorch_model.{variant}.safetensors", - "diffusion_pytorch_model.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) + def test_sharded_variants_in_main_dir_downloaded(self): + variant = "fp16" + filenames = [ + "diffusion_pytorch_model.safetensors.index.json", + "diffusion_pytorch_model-00001-of-00003.safetensors", + 
"diffusion_pytorch_model-00002-of-00003.safetensors", + "diffusion_pytorch_model-00003-of-00003.safetensors", + f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", + f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", + f"diffusion_pytorch_model.safetensors.index.{variant}.json", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f for f in model_filenames) + + def test_mixed_sharded_and_variant_in_main_dir_downloaded(self): + variant = "fp16" + filenames = [ + "diffusion_pytorch_model.safetensors.index.json", + "diffusion_pytorch_model-00001-of-00003.safetensors", + "diffusion_pytorch_model-00002-of-00003.safetensors", + "diffusion_pytorch_model-00003-of-00003.safetensors", + f"diffusion_pytorch_model.{variant}.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f for f in model_filenames) + def test_sharded_non_variants_downloaded(self): variant = "fp16" filenames = [ @@ -319,6 +341,35 @@ def test_sharded_variants_downloaded(self): model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) assert all(variant in f for f in model_filenames) + def test_single_variant_with_sharded_non_variant_downloaded(self): + variant = "fp16" + filenames = [ + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"unet/diffusion_pytorch_model.{variant}.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f for f in model_filenames) + + def test_mixed_single_variant_with_sharded_non_variant_downloaded(self): + variant = "fp16" + allowed_non_variant = "unet" + filenames = [ + "vae/diffusion_pytorch_model.safetensors.index.json", + "vae/diffusion_pytorch_model-00001-of-00003.safetensors", + "vae/diffusion_pytorch_model-00002-of-00003.safetensors", + "vae/diffusion_pytorch_model-00003-of-00003.safetensors", + f"vae/diffusion_pytorch_model.{variant}.safetensors", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + def test_sharded_mixed_variants_downloaded(self): variant = "fp16" allowed_non_variant = "unet" @@ -338,6 +389,18 @@ def test_sharded_mixed_variants_downloaded(self): model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + def test_empty_filenames(self): + model_filenames, variant_filenames = variant_compatible_siblings([], variant="fp16") + assert len(model_filenames) == 0 + assert len(variant_filenames) == 0 + + def test_invalid_filenames(self): + variant = "fp16" + filenames = ["invalid_file.txt", ".hidden", "model.", f"model.{variant}."] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert len(model_filenames) == 0 + assert 
len(variant_filenames) == 0 + class ProgressBarTests(unittest.TestCase): def get_dummy_components_image_generation(self): From 9f0ae2f523b0f5d57bf7d159fa5322673797ae69 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 24 Jan 2025 15:15:36 +0100 Subject: [PATCH 02/16] update --- .../pipelines/pipeline_loading_utils.py | 19 +++---------------- tests/pipelines/test_pipeline_utils.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 248f1eba5e2f..431ef8d0092b 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -197,29 +197,16 @@ def find_component(filename): component = filename.split("/")[0] return component - def has_sharded_variant(component, variant, variant_filenames): - # If component exists check for sharded variant index filename - # If component doesn't exist check main dir for sharded variant index filename + def has_variant(filename, variant_filenames): + component = find_component(filename) component = component + "/" if component else "" - variant_index_re = re.compile( - rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" - ) - return any(f for f in variant_filenames if variant_index_re.match(f) is not None) - def has_variant(component, variant_filenames): - component = component + "/" if component else "" # Check for any variant file in this component return any(f.startswith(component) for f in variant_filenames) for filename in non_variant_filenames: - component = find_component(filename) - - # Determine if sharded variant exists based on index file - if has_sharded_variant(component, variant, variant_filenames): - continue - # If a variant exists skip adding to allowed patterns - if has_variant(component, variant_filenames): + if has_variant(filename, variant_filenames): continue usable_filenames.add(filename) diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index d8abe3e156d0..5a18c6990d26 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -313,6 +313,20 @@ def test_mixed_sharded_and_variant_in_main_dir_downloaded(self): model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) assert all(variant in f for f in model_filenames) + def test_mixed_sharded_non_variants_in_main_dir_downloaded(self): + variant = "fp16" + filenames = [ + f"diffusion_pytorch_model.safetensors.index.{variant}.json", + "diffusion_pytorch_model.safetensors.index.json", + "diffusion_pytorch_model-00001-of-00003.safetensors", + "diffusion_pytorch_model-00002-of-00003.safetensors", + "diffusion_pytorch_model-00003-of-00003.safetensors", + f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", + f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + assert all(variant not in f for f in model_filenames) + def test_sharded_non_variants_downloaded(self): variant = "fp16" filenames = [ From 974f67e1e2338962b38ab3a2c4f0dd1a54635455 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 24 Jan 2025 17:22:13 +0100 Subject: [PATCH 03/16] update --- src/diffusers/pipelines/pipeline_loading_utils.py | 1 - tests/pipelines/test_pipeline_utils.py | 12 ------------ 2 files changed, 13 deletions(-) diff --git 
a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 431ef8d0092b..10ac86f488d6 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -201,7 +201,6 @@ def has_variant(filename, variant_filenames): component = find_component(filename) component = component + "/" if component else "" - # Check for any variant file in this component return any(f.startswith(component) for f in variant_filenames) for filename in non_variant_filenames: diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 5a18c6990d26..332b439981e7 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -403,18 +403,6 @@ def test_sharded_mixed_variants_downloaded(self): model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) - def test_empty_filenames(self): - model_filenames, variant_filenames = variant_compatible_siblings([], variant="fp16") - assert len(model_filenames) == 0 - assert len(variant_filenames) == 0 - - def test_invalid_filenames(self): - variant = "fp16" - filenames = ["invalid_file.txt", ".hidden", "model.", f"model.{variant}."] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) - assert len(model_filenames) == 0 - assert len(variant_filenames) == 0 - class ProgressBarTests(unittest.TestCase): def get_dummy_components_image_generation(self): From 9f9db3bfc8b8a9bb99081e4f4a7b6c8274024a26 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 24 Jan 2025 17:37:28 +0100 Subject: [PATCH 04/16] update --- tests/pipelines/test_pipeline_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 332b439981e7..1e4723802c0e 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -403,6 +403,12 @@ def test_sharded_mixed_variants_downloaded(self): model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + def test_downloading_when_no_variant_exists(self): + variant = "fp16" + filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"] + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + assert len(model_filenames) != 0 + class ProgressBarTests(unittest.TestCase): def get_dummy_components_image_generation(self): From 2089700d4bff0e2c9e8608cce258bdff6382e629 Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 29 Jan 2025 11:37:30 +0530 Subject: [PATCH 05/16] update --- .../pipelines/pipeline_loading_utils.py | 77 +++++++++-- src/diffusers/pipelines/pipeline_utils.py | 4 +- tests/pipelines/test_pipeline_utils.py | 127 +++++++++++++++--- 3 files changed, 184 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 10ac86f488d6..b2161362a7e2 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -141,7 +141,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No return True -def variant_compatible_siblings(filenames, variant=None) 
-> Union[List[os.PathLike], str]: +def variant_compatible_siblings(filenames, variant=None, use_safetensors=True) -> Union[List[os.PathLike], str]: weight_names = [ WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, @@ -188,24 +188,85 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None} non_variant_filenames = non_variant_weights | non_variant_indexes - # all variant filenames will be used by default - usable_filenames = set(variant_filenames) - def find_component(filename): if not len(filename.split("/")) == 2: return component = filename.split("/")[0] return component - def has_variant(filename, variant_filenames): + def convert_to_variant(filename): + if "index" in filename: + variant_filename = filename.replace("index", f"index.{variant}") + elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None: + variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}" + else: + variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" + return variant_filename + + def has_sharded_variant(filename, variant, variant_filenames): component = find_component(filename) + # If component exists check for sharded variant index filename + # If component doesn't exist check main dir for sharded variant index filename component = component + "/" if component else "" + variant_index_re = re.compile( + rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" + ) + return any(f for f in variant_filenames if variant_index_re.match(f) is not None) + + def has_non_sharded_variant(filename, variant, variant_filenames): + component = find_component(filename) + component = component + "/" if component else "" + base_name = filename.split("/")[-1] + + # Only apply to sharded files (those with the index format) + if not (non_variant_file_re.match(base_name) or non_variant_index_re.match(base_name)): + return False - return any(f.startswith(component) for f in variant_filenames) + # Check if there's a non-sharded variant in the same component + non_sharded_variants = [ + f + for f in variant_filenames + if f.startswith(component) and not re.search(transformers_index_format, f.split("/")[-1]) + ] + return any(non_sharded_variants) + + if use_safetensors: + # Keep only safetensors and index files + non_variant_filenames = { + f + for f in non_variant_filenames + if f.endswith(".safetensors") or non_variant_index_re.match(f.split("/")[-1]) + } + if variant is not None: + variant_filenames = { + f for f in variant_filenames if f.endswith(".safetensors") or variant_index_re.match(f.split("/")[-1]) + } + else: + # Exclude safetensors files but keep index files + non_variant_filenames = { + f + for f in non_variant_filenames + if not f.endswith(".safetensors") or non_variant_index_re.match(f.split("/")[-1]) + } + if variant is not None: + variant_filenames = { + f + for f in variant_filenames + if not f.endswith(".safetensors") or variant_index_re.match(f.split("/")[-1]) + } + + # all variant filenames will be used by default + usable_filenames = set(variant_filenames) for filename in non_variant_filenames: - # If a variant exists skip adding to allowed patterns - if has_variant(filename, variant_filenames): + if convert_to_variant(filename) in variant_filenames: + continue + + # If a sharded variant exists skip adding to allowed patterns + if has_sharded_variant(filename, 
variant, variant_filenames): + continue + + if has_non_sharded_variant(filename, variant, variant_filenames): continue usable_filenames.add(filename) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d56a2ce6eb30..53629821c30a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1375,7 +1375,9 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: ) logger.warning(warn_msg) - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) config_file = hf_hub_download( pretrained_model_name, diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 1e4723802c0e..e5908b692bd8 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -212,6 +212,7 @@ def test_diffusers_is_compatible_no_components_only_variants(self): class VariantCompatibleSiblingsTest(unittest.TestCase): def test_only_non_variants_downloaded(self): + use_safetensors = True variant = "fp16" filenames = [ f"vae/diffusion_pytorch_model.{variant}.safetensors", @@ -222,10 +223,13 @@ def test_only_non_variants_downloaded(self): "unet/diffusion_pytorch_model.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=None, use_safetensors=use_safetensors + ) assert all(variant not in f for f in model_filenames) def test_only_variants_downloaded(self): + use_safetensors = True variant = "fp16" filenames = [ f"vae/diffusion_pytorch_model.{variant}.safetensors", @@ -236,10 +240,13 @@ def test_only_variants_downloaded(self): "unet/diffusion_pytorch_model.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) assert all(variant in f for f in model_filenames) def test_mixed_variants_downloaded(self): + use_safetensors = True variant = "fp16" non_variant_file = "text_encoder/model.safetensors" filenames = [ @@ -249,10 +256,13 @@ def test_mixed_variants_downloaded(self): f"unet/diffusion_pytorch_model.{variant}.safetensors", "unet/diffusion_pytorch_model.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) def test_non_variants_in_main_dir_downloaded(self): + use_safetensors = True variant = "fp16" filenames = [ f"diffusion_pytorch_model.{variant}.safetensors", @@ -260,10 +270,13 @@ def test_non_variants_in_main_dir_downloaded(self): "model.safetensors", f"model.{variant}.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=None, use_safetensors=use_safetensors + ) assert all(variant not in f for f in model_filenames) def test_variants_in_main_dir_downloaded(self): + use_safetensors = True variant = "fp16" filenames = 
[ f"diffusion_pytorch_model.{variant}.safetensors", @@ -273,10 +286,13 @@ def test_variants_in_main_dir_downloaded(self): f"diffusion_pytorch_model.{variant}.safetensors", "diffusion_pytorch_model.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) assert all(variant in f for f in model_filenames) def test_mixed_variants_in_main_dir_downloaded(self): + use_safetensors = True variant = "fp16" non_variant_file = "model.safetensors" filenames = [ @@ -284,10 +300,13 @@ def test_mixed_variants_in_main_dir_downloaded(self): "diffusion_pytorch_model.safetensors", "model.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) def test_sharded_variants_in_main_dir_downloaded(self): + use_safetensors = True variant = "fp16" filenames = [ "diffusion_pytorch_model.safetensors.index.json", @@ -298,10 +317,13 @@ def test_sharded_variants_in_main_dir_downloaded(self): f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", f"diffusion_pytorch_model.safetensors.index.{variant}.json", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) assert all(variant in f for f in model_filenames) def test_mixed_sharded_and_variant_in_main_dir_downloaded(self): + use_safetensors = True variant = "fp16" filenames = [ "diffusion_pytorch_model.safetensors.index.json", @@ -310,10 +332,13 @@ def test_mixed_sharded_and_variant_in_main_dir_downloaded(self): "diffusion_pytorch_model-00003-of-00003.safetensors", f"diffusion_pytorch_model.{variant}.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) assert all(variant in f for f in model_filenames) def test_mixed_sharded_non_variants_in_main_dir_downloaded(self): + use_safetensors = True variant = "fp16" filenames = [ f"diffusion_pytorch_model.safetensors.index.{variant}.json", @@ -324,10 +349,13 @@ def test_mixed_sharded_non_variants_in_main_dir_downloaded(self): f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=None, use_safetensors=use_safetensors + ) assert all(variant not in f for f in model_filenames) def test_sharded_non_variants_downloaded(self): + use_safetensors = True variant = "fp16" filenames = [ f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json", @@ -338,10 +366,13 @@ def test_sharded_non_variants_downloaded(self): f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + 
model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=None, use_safetensors=use_safetensors + ) assert all(variant not in f for f in model_filenames) def test_sharded_variants_downloaded(self): + use_safetensors = True variant = "fp16" filenames = [ f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json", @@ -352,10 +383,13 @@ def test_sharded_variants_downloaded(self): f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) assert all(variant in f for f in model_filenames) def test_single_variant_with_sharded_non_variant_downloaded(self): + use_safetensors = True variant = "fp16" filenames = [ "unet/diffusion_pytorch_model.safetensors.index.json", @@ -364,10 +398,13 @@ def test_single_variant_with_sharded_non_variant_downloaded(self): "unet/diffusion_pytorch_model-00003-of-00003.safetensors", f"unet/diffusion_pytorch_model.{variant}.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) assert all(variant in f for f in model_filenames) def test_mixed_single_variant_with_sharded_non_variant_downloaded(self): + use_safetensors = True variant = "fp16" allowed_non_variant = "unet" filenames = [ @@ -381,10 +418,13 @@ def test_mixed_single_variant_with_sharded_non_variant_downloaded(self): "unet/diffusion_pytorch_model-00002-of-00003.safetensors", "unet/diffusion_pytorch_model-00003-of-00003.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) def test_sharded_mixed_variants_downloaded(self): + use_safetensors = True variant = "fp16" allowed_non_variant = "unet" filenames = [ @@ -400,15 +440,72 @@ def test_sharded_mixed_variants_downloaded(self): "vae/diffusion_pytorch_model-00002-of-00003.safetensors", "vae/diffusion_pytorch_model-00003-of-00003.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + def test_variant_ignored_if_use_safetensors(self): + use_safetensors = True + variant = "fp16" + filenames = [ + f"vae/diffusion_pytorch_model.{variant}.bin", + f"text_encoder/model.{variant}.bin", + f"unet/diffusion_pytorch_model.{variant}.bin", + "vae/diffusion_pytorch_model.safetensors", + "text_encoder/model.safetensors", + "unet/diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) + assert all(variant not in f for f in model_filenames) + def test_downloading_when_no_variant_exists(self): + use_safetensors = True variant = "fp16" filenames = 
["model.safetensors", "diffusion_pytorch_model.safetensors"] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) assert len(model_filenames) != 0 + def test_downloading_use_safetensors_no_variant_exists(self): + use_safetensors = True + variant = "fp16" + filenames = ["text_encoder/model.bin", "unet/diffusion_pytorch_model.bin"] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) + assert all(variant not in f for f in model_filenames) + + def test_downloading_use_safetensors_false(self): + use_safetensors = False + variant = "fp16" + filenames = [ + "text_encoder/model.bin", + "unet/diffusion_pytorch_model.bin", + "unet/diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) + + assert all(".safetensors" not in f for f in model_filenames) + + def test_non_variant_in_main_dir_with_variant_in_subfolder(self): + use_safetensors = True + variant = "fp16" + allowed_non_variant = "diffusion_pytorch_model.safetensors" + filenames = [ + f"unet/diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) + assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + class ProgressBarTests(unittest.TestCase): def get_dummy_components_image_generation(self): From a4bdc970cad9be38afc4e739bc19b605f9ef3df7 Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 29 Jan 2025 19:08:18 +0530 Subject: [PATCH 06/16] update --- .../pipelines/pipeline_loading_utils.py | 128 +++++++----------- tests/pipelines/test_pipeline_utils.py | 59 ++++++++ 2 files changed, 107 insertions(+), 80 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index b2161362a7e2..c61c14484ae7 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -177,99 +177,67 @@ def variant_compatible_siblings(filenames, variant=None, use_safetensors=True) - # `text_encoder/pytorch_model.bin.index.json` non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json") - if variant is not None: - variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None} - variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None} - variant_filenames = variant_weights | variant_indexes - else: - variant_filenames = set() + def filter_for_compatible_extensions(filenames, variant=None, use_safetensors=True): + def is_safetensors(filename): + return ".safetensors" in filename - non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None} - non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None} - non_variant_filenames = non_variant_weights | non_variant_indexes + def is_not_safetensors(filename): + return ".safetensors" not in filename - def find_component(filename): - if not len(filename.split("/")) == 2: - return - component = filename.split("/")[0] - 
return component - - def convert_to_variant(filename): - if "index" in filename: - variant_filename = filename.replace("index", f"index.{variant}") - elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None: - variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}" + if use_safetensors and is_safetensors_compatible(filenames): + extension_filter = is_safetensors else: - variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" - return variant_filename - - def has_sharded_variant(filename, variant, variant_filenames): - component = find_component(filename) - # If component exists check for sharded variant index filename - # If component doesn't exist check main dir for sharded variant index filename - component = component + "/" if component else "" - variant_index_re = re.compile( - rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" - ) - return any(f for f in variant_filenames if variant_index_re.match(f) is not None) - - def has_non_sharded_variant(filename, variant, variant_filenames): - component = find_component(filename) - component = component + "/" if component else "" - base_name = filename.split("/")[-1] - - # Only apply to sharded files (those with the index format) - if not (non_variant_file_re.match(base_name) or non_variant_index_re.match(base_name)): - return False + extension_filter = is_not_safetensors - # Check if there's a non-sharded variant in the same component - non_sharded_variants = [ - f - for f in variant_filenames - if f.startswith(component) and not re.search(transformers_index_format, f.split("/")[-1]) - ] - return any(non_sharded_variants) - - if use_safetensors: - # Keep only safetensors and index files - non_variant_filenames = { - f - for f in non_variant_filenames - if f.endswith(".safetensors") or non_variant_index_re.match(f.split("/")[-1]) + tensor_files = {f for f in filenames if extension_filter(f)} + non_variant_indexes = { + f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None and extension_filter(f) } - if variant is not None: - variant_filenames = { - f for f in variant_filenames if f.endswith(".safetensors") or variant_index_re.match(f.split("/")[-1]) - } - else: - # Exclude safetensors files but keep index files - non_variant_filenames = { + variant_indexes = { f - for f in non_variant_filenames - if not f.endswith(".safetensors") or non_variant_index_re.match(f.split("/")[-1]) + for f in filenames + if variant is not None and variant_index_re.match(f.split("/")[-1]) is not None and extension_filter(f) } - if variant is not None: - variant_filenames = { - f - for f in variant_filenames - if not f.endswith(".safetensors") or variant_index_re.match(f.split("/")[-1]) - } - # all variant filenames will be used by default - usable_filenames = set(variant_filenames) + return tensor_files | non_variant_indexes | variant_indexes - for filename in non_variant_filenames: - if convert_to_variant(filename) in variant_filenames: - continue + def filter_for_weights_and_indexes(filenames, file_re, index_re): + weights = {f for f in filenames if file_re.match(f.split("/")[-1]) is not None} + indexes = {f for f in filenames if index_re.match(f.split("/")[-1]) is not None} + filtered_filenames = weights | indexes - # If a sharded variant exists skip adding to allowed patterns - if has_sharded_variant(filename, variant, variant_filenames): - continue + return filtered_filenames - if 
has_non_sharded_variant(filename, variant, variant_filenames): + # Group files by component + components = {} + for filename in filenames: + if not len(filename.split("/")) == 2: + components.setdefault("", []).append(filename) continue - usable_filenames.add(filename) + component, _ = filename.split("/") + components.setdefault(component, []).append(filename) + + usable_filenames = set() + variant_filenames = set() + for component, component_filenames in components.items(): + component_filenames = filter_for_compatible_extensions( + component_filenames, variant=variant, use_safetensors=use_safetensors + ) + + component_variants = set() + if variant is not None: + component_variants = filter_for_weights_and_indexes(component_filenames, variant_file_re, variant_index_re) + + if component_variants: + variant_filenames.update(component_variants) + usable_filenames.update(component_variants) + + else: + component_non_variants = filter_for_weights_and_indexes( + component_filenames, non_variant_file_re, non_variant_index_re + ) + usable_filenames.update(component_non_variants) return usable_filenames, variant_filenames diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index e5908b692bd8..3ec2a2929f2a 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -506,6 +506,65 @@ def test_non_variant_in_main_dir_with_variant_in_subfolder(self): ) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + def test_download_variants_when_component_has_no_variant(self): + use_safetensors = True + variant = "fp16" + filenames = [ + f"unet/diffusion_pytorch_model.{variant}.bin", + "vae/diffusion_pytorch_model.safetensors", + f"vae/diffusion_pytorch_model.{variant}.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) + assert { + f"unet/diffusion_pytorch_model.{variant}.bin", + f"vae/diffusion_pytorch_model.{variant}.safetensors", + } == model_filenames + + def test_download_sharded_variants_when_component_has_no_safetensors_variant(self): + use_safetensors = True + variant = "fp16" + filenames = [ + f"vae/diffusion_pytorch_model.bin.index.{variant}.json", + "vae/diffusion_pytorch_model.safetensors.index.json", + f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin", + "vae/diffusion_pytorch_model-00001-of-00003.safetensors", + "vae/diffusion_pytorch_model-00002-of-00003.safetensors", + "vae/diffusion_pytorch_model-00003-of-00003.safetensors", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) + assert all(variant not in f for f in model_filenames) + + def test_download_sharded_variants_when_component_has_no_safetensors_variant_and_safetensors_false(self): + use_safetensors = False + allowed_non_variant = "unet" + variant = "fp16" + filenames = [ + f"vae/diffusion_pytorch_model.bin.index.{variant}.json", + "vae/diffusion_pytorch_model.safetensors.index.json", + f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin", + "vae/diffusion_pytorch_model-00001-of-00003.safetensors", + 
"vae/diffusion_pytorch_model-00002-of-00003.safetensors", + "vae/diffusion_pytorch_model-00003-of-00003.safetensors", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, use_safetensors=use_safetensors + ) + assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + class ProgressBarTests(unittest.TestCase): def get_dummy_components_image_generation(self): From 04d7dc3afae930ab1618e97416a2d930f9309f34 Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 29 Jan 2025 21:54:30 +0530 Subject: [PATCH 07/16] update --- src/diffusers/pipelines/pipeline_loading_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index c61c14484ae7..ec19b9aeac06 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -104,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No extension is replaced with ".safetensors" """ passed_components = passed_components or [] - if folder_names is not None: + if folder_names: filenames = {f for f in filenames if os.path.split(f)[0] in folder_names} # extract all components of the pipeline and their associated files From c40f60cd463987eb014138dabed6796dfa671e69 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 20 Feb 2025 11:10:48 +0100 Subject: [PATCH 08/16] update --- .../pipelines/pipeline_loading_utils.py | 55 +++------ src/diffusers/pipelines/pipeline_utils.py | 112 ++++++++++-------- 2 files changed, 79 insertions(+), 88 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index ec19b9aeac06..e804d7188302 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -141,7 +141,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No return True -def variant_compatible_siblings(filenames, variant=None, use_safetensors=True) -> Union[List[os.PathLike], str]: +def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]: weight_names = [ WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, @@ -177,17 +177,9 @@ def variant_compatible_siblings(filenames, variant=None, use_safetensors=True) - # `text_encoder/pytorch_model.bin.index.json` non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json") - def filter_for_compatible_extensions(filenames, variant=None, use_safetensors=True): - def is_safetensors(filename): - return ".safetensors" in filename - - def is_not_safetensors(filename): - return ".safetensors" not in filename - - if use_safetensors and is_safetensors_compatible(filenames): - extension_filter = is_safetensors - else: - extension_filter = is_not_safetensors + def filter_for_compatible_extensions(filenames, variant=None, ignore_patterns=None): + def extension_filter(f): + return not any(f.endswith(pattern) for pattern in ignore_patterns) tensor_files = {f for f in filenames if extension_filter(f)} 
non_variant_indexes = { @@ -222,7 +214,7 @@ def filter_for_weights_and_indexes(filenames, file_re, index_re): variant_filenames = set() for component, component_filenames in components.items(): component_filenames = filter_for_compatible_extensions( - component_filenames, variant=variant, use_safetensors=use_safetensors + component_filenames, variant=variant, ignore_patterns=ignore_patterns ) component_variants = set() @@ -239,6 +231,18 @@ def filter_for_weights_and_indexes(filenames, file_re, index_re): ) usable_filenames.update(component_non_variants) + if len(variant_filenames) == 0 and variant is not None: + error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." + raise ValueError(error_message) + + if len(variant_filenames) > 0 and usable_filenames != variant_filenames: + logger.warning( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" + f"[{', '.join(variant_filenames)}]\nLoaded non-{variant} filenames:\n" + f"[{', '.join(usable_filenames - variant_filenames)}\nIf this behavior is not " + f"expected, please check your folder structure." + ) + return usable_filenames, variant_filenames @@ -933,10 +937,6 @@ def _get_custom_components_and_folders( f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." ) - if len(variant_filenames) == 0 and variant is not None: - error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." - raise ValueError(error_message) - return custom_components, folder_names @@ -944,7 +944,6 @@ def _get_ignore_patterns( passed_components, model_folder_names: List[str], model_filenames: List[str], - variant_filenames: List[str], use_safetensors: bool, from_flax: bool, allow_pickle: bool, @@ -975,16 +974,6 @@ def _get_ignore_patterns( if not use_onnx: ignore_patterns += ["*.onnx", "*.pb"] - safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} - safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} - if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames: - logger.warning( - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" - f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n" - f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not " - f"expected, please check your folder structure." - ) - else: ignore_patterns = ["*.safetensors", "*.msgpack"] @@ -992,16 +981,6 @@ def _get_ignore_patterns( if not use_onnx: ignore_patterns += ["*.onnx", "*.pb"] - bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} - bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} - if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: - logger.warning( - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" - f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n" - f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check " - f"your folder structure." 
- ) - return ignore_patterns diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 53629821c30a..70b8f42a7806 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1343,10 +1343,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: revision=revision, ) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True + allow_pickle = True if (use_safetensors is None or use_safetensors is False) else False + use_safetensors = use_safetensors if use_safetensors is not None else True allow_patterns = None ignore_patterns = None @@ -1361,6 +1359,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: model_info_call_error = e # save error to reraise it if model is not cached locally if not local_files_only: + config_file = hf_hub_download( + pretrained_model_name, + cls.config_name, + cache_dir=cache_dir, + revision=revision, + proxies=proxies, + force_download=force_download, + token=token, + ) + config_dict = cls._dict_from_json_file(config_file) + ignore_filenames = config_dict.pop("_ignore_files", []) + filenames = {sibling.rfilename for sibling in info.siblings} if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant): warn_msg = ( @@ -1375,61 +1385,20 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: ) logger.warning(warn_msg) - model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors - ) - - config_file = hf_hub_download( - pretrained_model_name, - cls.config_name, - cache_dir=cache_dir, - revision=revision, - proxies=proxies, - force_download=force_download, - token=token, - ) - - config_dict = cls._dict_from_json_file(config_file) - ignore_filenames = config_dict.pop("_ignore_files", []) - - # remove ignored filenames - model_filenames = set(model_filenames) - set(ignore_filenames) - variant_filenames = set(variant_filenames) - set(ignore_filenames) - + filenames = set(filenames) - set(ignore_filenames) if revision in DEPRECATED_REVISION_ARGS and version.parse( version.parse(__version__).base_version ) >= version.parse("0.22.0"): - warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames) + warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, filenames) custom_components, folder_names = _get_custom_components_and_folders( - pretrained_model_name, config_dict, filenames, variant_filenames, variant + pretrained_model_name, config_dict, filenames, variant ) - model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names} - custom_class_name = None if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)): custom_pipeline = config_dict["_class_name"][0] custom_class_name = config_dict["_class_name"][1] - # all filenames compatible with variant will be added - allow_patterns = list(model_filenames) - - # allow all patterns from non-model folders - # this enables downloading schedulers, tokenizers, ... 
- allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names] - # add custom component files - allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()] - # add custom pipeline file - allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else [] - # also allow downloading config.json files with the model - allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] - allow_patterns += [ - SCHEDULER_CONFIG_NAME, - CONFIG_NAME, - cls.config_name, - CUSTOM_PIPELINE_FILE_NAME, - ] - load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames load_components_from_hub = len(custom_components) > 0 @@ -1446,6 +1415,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n" f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." ) + model_folder_names = {os.path.split(f)[0] for f in filenames if os.path.split(f)[0] in folder_names} # retrieve passed components that should not be downloaded pipeline_class = _get_pipeline_class( @@ -1466,8 +1436,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: ignore_patterns = _get_ignore_patterns( passed_components, model_folder_names, - model_filenames, - variant_filenames, + filenames, use_safetensors, from_flax, allow_pickle, @@ -1476,6 +1445,49 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: variant, ) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) + + safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} + safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} + if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames: + logger.warning( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" + f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n" + f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not " + f"expected, please check your folder structure." + ) + + bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} + bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} + if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: + logger.warning( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" + f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n" + f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check " + f"your folder structure." + ) + + # all filenames compatible with variant will be added + allow_patterns = list(model_filenames) + + # allow all patterns from non-model folders + # this enables downloading schedulers, tokenizers, ... 
+ allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names] + # add custom component files + allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()] + # add custom pipeline file + allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else [] + # also allow downloading config.json files with the model + allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] + allow_patterns += [ + SCHEDULER_CONFIG_NAME, + CONFIG_NAME, + cls.config_name, + CUSTOM_PIPELINE_FILE_NAME, + ] + # Don't download any objects that are passed allow_patterns = [ p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components) From ac4c23c154a49d31bc9211abbf6e0073818d83f5 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 20 Feb 2025 13:57:18 +0100 Subject: [PATCH 09/16] update --- .../pipelines/pipeline_loading_utils.py | 55 ++++++++----------- tests/pipelines/test_pipeline_utils.py | 54 +++++++++--------- 2 files changed, 51 insertions(+), 58 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index e804d7188302..f8db96ade21d 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -165,6 +165,7 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) - variant_file_re = re.compile( rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$" ) + legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$") # `text_encoder/pytorch_model.bin.index.fp16.json` variant_index_re = re.compile( rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" @@ -177,28 +178,16 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) - # `text_encoder/pytorch_model.bin.index.json` non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json") - def filter_for_compatible_extensions(filenames, variant=None, ignore_patterns=None): - def extension_filter(f): - return not any(f.endswith(pattern) for pattern in ignore_patterns) + def filter_for_compatible_extensions(filenames, ignore_patterns=None): + if not ignore_patterns: + return filenames - tensor_files = {f for f in filenames if extension_filter(f)} - non_variant_indexes = { - f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None and extension_filter(f) - } - variant_indexes = { - f - for f in filenames - if variant is not None and variant_index_re.match(f.split("/")[-1]) is not None and extension_filter(f) - } - - return tensor_files | non_variant_indexes | variant_indexes + # ignore patterns uses glob style patterns e.g *.safetensors but we're only + # interested in the extension name + return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)} - def filter_for_weights_and_indexes(filenames, file_re, index_re): - weights = {f for f in filenames if file_re.match(f.split("/")[-1]) is not None} - indexes = {f for f in filenames if index_re.match(f.split("/")[-1]) is not None} - filtered_filenames = weights | indexes - - return filtered_filenames + def filter_with_regex(filenames, pattern_re): + return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None} # Group files by component components = {} @@ -213,23 +202,27 @@ def 
filter_for_weights_and_indexes(filenames, file_re, index_re): usable_filenames = set() variant_filenames = set() for component, component_filenames in components.items(): - component_filenames = filter_for_compatible_extensions( - component_filenames, variant=variant, ignore_patterns=ignore_patterns - ) + component_filenames = filter_for_compatible_extensions(component_filenames, ignore_patterns=ignore_patterns) component_variants = set() + component_legacy_variants = set() + component_non_variants = set() if variant is not None: - component_variants = filter_for_weights_and_indexes(component_filenames, variant_file_re, variant_index_re) + component_variants = filter_with_regex(component_filenames, variant_file_re) + component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re) + component_variant_index_files = filter_with_regex(component_filenames, variant_index_re) - if component_variants: - variant_filenames.update(component_variants) - usable_filenames.update(component_variants) + variant_filenames.update( + component_variants if component_variants else component_legacy_variants | component_variant_index_files + ) else: - component_non_variants = filter_for_weights_and_indexes( - component_filenames, non_variant_file_re, non_variant_index_re - ) - usable_filenames.update(component_non_variants) + component_non_variants = filter_with_regex(component_filenames, non_variant_file_re) + component_variant_index_files = filter_with_regex(component_filenames, non_variant_index_re) + + usable_filenames.update(component_non_variants | component_variant_index_files) + + usable_filenames.update(variant_filenames) if len(variant_filenames) == 0 and variant is not None: error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." 
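# --- Illustrative sketch only (not part of the patch above): the shape of the
# per-component selection that the reworked variant_compatible_siblings() now
# performs -- group files by component folder, drop extensions excluded by
# ignore_patterns, then prefer variant weights per component and fall back to
# the plain ones. pick_files, WEIGHT_RE and the demo filenames are assumptions
# for illustration, not the diffusers implementation.
import re
from collections import defaultdict

# Simplified stand-in for the weight-name regexes defined earlier in the file.
WEIGHT_RE = re.compile(r"(diffusion_pytorch_model|model)\.(safetensors|bin)$")


def pick_files(filenames, variant=None, ignore_patterns=None):
    """Toy per-component selection: prefer `variant` weights inside each
    component folder, otherwise keep the non-variant weights."""
    ignore_patterns = ignore_patterns or []

    def keep(f):
        # ignore_patterns uses glob-style suffixes such as "*.bin"
        return not any(f.endswith(p.lstrip("*")) for p in ignore_patterns)

    components = defaultdict(set)
    for f in filter(keep, filenames):
        components[f.split("/")[0] if "/" in f else ""].add(f)

    usable, variant_files = set(), set()
    for files in components.values():
        variants = {f for f in files if variant and f".{variant}." in f}
        if variants:
            variant_files |= variants
        else:
            usable |= {f for f in files if WEIGHT_RE.match(f.split("/")[-1])}
    return usable | variant_files, variant_files


if __name__ == "__main__":
    demo = [
        "unet/diffusion_pytorch_model.fp16.safetensors",
        "unet/diffusion_pytorch_model.safetensors",
        "vae/diffusion_pytorch_model.safetensors",
        "text_encoder/model.fp16.bin",
    ]
    usable, variants = pick_files(demo, variant="fp16", ignore_patterns=["*.bin", "*.msgpack"])
    # Expected: the fp16 unet weights plus the plain vae weights; the .bin file is ignored.
    print(sorted(usable), sorted(variants))
# --- end of sketch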
diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 3ec2a2929f2a..fbee044ec20a 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -1,6 +1,7 @@ import contextlib import io import re +from shutil import ignore_patterns import unittest import torch @@ -212,7 +213,7 @@ def test_diffusers_is_compatible_no_components_only_variants(self): class VariantCompatibleSiblingsTest(unittest.TestCase): def test_only_non_variants_downloaded(self): - use_safetensors = True + ignore_patterns = ["*.bin", "*.msgpack"] variant = "fp16" filenames = [ f"vae/diffusion_pytorch_model.{variant}.safetensors", @@ -224,12 +225,12 @@ def test_only_non_variants_downloaded(self): ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=None, use_safetensors=use_safetensors + filenames, variant=None, ignore_patterns=ignore_patterns ) assert all(variant not in f for f in model_filenames) def test_only_variants_downloaded(self): - use_safetensors = True + ignore_patterns = ["*.bin", "*.msgpack"] variant = "fp16" filenames = [ f"vae/diffusion_pytorch_model.{variant}.safetensors", @@ -241,12 +242,12 @@ def test_only_variants_downloaded(self): ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant in f for f in model_filenames) def test_mixed_variants_downloaded(self): - use_safetensors = True + ignore_patterns = ["*.bin", "*.msgpack"] variant = "fp16" non_variant_file = "text_encoder/model.safetensors" filenames = [ @@ -257,12 +258,12 @@ def test_mixed_variants_downloaded(self): "unet/diffusion_pytorch_model.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) def test_non_variants_in_main_dir_downloaded(self): - use_safetensors = True + ignore_patterns = ["*.bin", "*.msgpack"] variant = "fp16" filenames = [ f"diffusion_pytorch_model.{variant}.safetensors", @@ -271,12 +272,11 @@ def test_non_variants_in_main_dir_downloaded(self): f"model.{variant}.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=None, use_safetensors=use_safetensors + filenames, variant=None, ignore_patterns=ignore_patterns ) assert all(variant not in f for f in model_filenames) def test_variants_in_main_dir_downloaded(self): - use_safetensors = True variant = "fp16" filenames = [ f"diffusion_pytorch_model.{variant}.safetensors", @@ -287,7 +287,7 @@ def test_variants_in_main_dir_downloaded(self): "diffusion_pytorch_model.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant in f for f in model_filenames) @@ -301,7 +301,7 @@ def test_mixed_variants_in_main_dir_downloaded(self): "model.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) @@ -318,7 +318,7 @@ def 
test_sharded_variants_in_main_dir_downloaded(self): f"diffusion_pytorch_model.safetensors.index.{variant}.json", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant in f for f in model_filenames) @@ -333,7 +333,7 @@ def test_mixed_sharded_and_variant_in_main_dir_downloaded(self): f"diffusion_pytorch_model.{variant}.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant in f for f in model_filenames) @@ -350,7 +350,7 @@ def test_mixed_sharded_non_variants_in_main_dir_downloaded(self): f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=None, use_safetensors=use_safetensors + filenames, variant=None, ignore_patterns=ignore_patterns ) assert all(variant not in f for f in model_filenames) @@ -367,7 +367,7 @@ def test_sharded_non_variants_downloaded(self): f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=None, use_safetensors=use_safetensors + filenames, variant=None, ignore_patterns=ignore_patterns ) assert all(variant not in f for f in model_filenames) @@ -384,7 +384,7 @@ def test_sharded_variants_downloaded(self): f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant in f for f in model_filenames) @@ -399,7 +399,7 @@ def test_single_variant_with_sharded_non_variant_downloaded(self): f"unet/diffusion_pytorch_model.{variant}.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant in f for f in model_filenames) @@ -419,7 +419,7 @@ def test_mixed_single_variant_with_sharded_non_variant_downloaded(self): "unet/diffusion_pytorch_model-00003-of-00003.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) @@ -441,7 +441,7 @@ def test_sharded_mixed_variants_downloaded(self): "vae/diffusion_pytorch_model-00003-of-00003.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) @@ -457,7 +457,7 @@ def test_variant_ignored_if_use_safetensors(self): "unet/diffusion_pytorch_model.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant not in f for f in model_filenames) @@ -466,7 +466,7 @@ def 
test_downloading_when_no_variant_exists(self): variant = "fp16" filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert len(model_filenames) != 0 @@ -475,7 +475,7 @@ def test_downloading_use_safetensors_no_variant_exists(self): variant = "fp16" filenames = ["text_encoder/model.bin", "unet/diffusion_pytorch_model.bin"] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant not in f for f in model_filenames) @@ -488,7 +488,7 @@ def test_downloading_use_safetensors_false(self): "unet/diffusion_pytorch_model.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(".safetensors" not in f for f in model_filenames) @@ -502,7 +502,7 @@ def test_non_variant_in_main_dir_with_variant_in_subfolder(self): "diffusion_pytorch_model.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) @@ -515,7 +515,7 @@ def test_download_variants_when_component_has_no_variant(self): f"vae/diffusion_pytorch_model.{variant}.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert { f"unet/diffusion_pytorch_model.{variant}.bin", @@ -539,7 +539,7 @@ def test_download_sharded_variants_when_component_has_no_safetensors_variant(sel f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant not in f for f in model_filenames) @@ -561,7 +561,7 @@ def test_download_sharded_variants_when_component_has_no_safetensors_variant_and f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, use_safetensors=use_safetensors + filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) From 420c78cb9049c462d78a4808cdef7b5c215fd1cc Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 20 Feb 2025 21:38:27 +0530 Subject: [PATCH 10/16] update --- .../pipelines/pipeline_loading_utils.py | 16 +++- tests/pipelines/test_pipeline_utils.py | 89 +++++++------------ 2 files changed, 44 insertions(+), 61 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index f8db96ade21d..d01469e50f8c 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -165,11 +165,14 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) - variant_file_re = re.compile( 
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$" ) - legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$") # `text_encoder/pytorch_model.bin.index.fp16.json` variant_index_re = re.compile( rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" ) + legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$") + legacy_variant_index_re = re.compile( + rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.{variant}\.index\.json$" + ) # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors` non_variant_file_re = re.compile( @@ -209,11 +212,16 @@ def filter_with_regex(filenames, pattern_re): component_non_variants = set() if variant is not None: component_variants = filter_with_regex(component_filenames, variant_file_re) - component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re) component_variant_index_files = filter_with_regex(component_filenames, variant_index_re) + component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re) + component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re) + + if component_variants: variant_filenames.update( - component_variants if component_variants else component_legacy_variants | component_variant_index_files + component_variants | component_variant_index_files + if component_variants + else component_legacy_variants | component_legacy_variant_index_files ) else: @@ -225,7 +233,7 @@ def filter_with_regex(filenames, pattern_re): usable_filenames.update(variant_filenames) if len(variant_filenames) == 0 and variant is not None: - error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." + error_message = f"You are trying to load model files of the `variant={variant}`, but no such modeling files are available. 
" raise ValueError(error_message) if len(variant_filenames) > 0 and usable_filenames != variant_filenames: diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index fbee044ec20a..2e222f14cbaf 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -1,7 +1,6 @@ import contextlib import io import re -from shutil import ignore_patterns import unittest import torch @@ -213,7 +212,7 @@ def test_diffusers_is_compatible_no_components_only_variants(self): class VariantCompatibleSiblingsTest(unittest.TestCase): def test_only_non_variants_downloaded(self): - ignore_patterns = ["*.bin", "*.msgpack"] + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ f"vae/diffusion_pytorch_model.{variant}.safetensors", @@ -230,7 +229,7 @@ def test_only_non_variants_downloaded(self): assert all(variant not in f for f in model_filenames) def test_only_variants_downloaded(self): - ignore_patterns = ["*.bin", "*.msgpack"] + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ f"vae/diffusion_pytorch_model.{variant}.safetensors", @@ -247,7 +246,7 @@ def test_only_variants_downloaded(self): assert all(variant in f for f in model_filenames) def test_mixed_variants_downloaded(self): - ignore_patterns = ["*.bin", "*.msgpack"] + ignore_patterns = ["*.bin"] variant = "fp16" non_variant_file = "text_encoder/model.safetensors" filenames = [ @@ -263,7 +262,7 @@ def test_mixed_variants_downloaded(self): assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) def test_non_variants_in_main_dir_downloaded(self): - ignore_patterns = ["*.bin", "*.msgpack"] + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ f"diffusion_pytorch_model.{variant}.safetensors", @@ -277,6 +276,7 @@ def test_non_variants_in_main_dir_downloaded(self): assert all(variant not in f for f in model_filenames) def test_variants_in_main_dir_downloaded(self): + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ f"diffusion_pytorch_model.{variant}.safetensors", @@ -292,7 +292,7 @@ def test_variants_in_main_dir_downloaded(self): assert all(variant in f for f in model_filenames) def test_mixed_variants_in_main_dir_downloaded(self): - use_safetensors = True + ignore_patterns = ["*.bin"] variant = "fp16" non_variant_file = "model.safetensors" filenames = [ @@ -306,7 +306,7 @@ def test_mixed_variants_in_main_dir_downloaded(self): assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) def test_sharded_variants_in_main_dir_downloaded(self): - use_safetensors = True + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ "diffusion_pytorch_model.safetensors.index.json", @@ -323,7 +323,7 @@ def test_sharded_variants_in_main_dir_downloaded(self): assert all(variant in f for f in model_filenames) def test_mixed_sharded_and_variant_in_main_dir_downloaded(self): - use_safetensors = True + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ "diffusion_pytorch_model.safetensors.index.json", @@ -338,7 +338,7 @@ def test_mixed_sharded_and_variant_in_main_dir_downloaded(self): assert all(variant in f for f in model_filenames) def test_mixed_sharded_non_variants_in_main_dir_downloaded(self): - use_safetensors = True + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ f"diffusion_pytorch_model.safetensors.index.{variant}.json", @@ -355,7 +355,7 @@ def test_mixed_sharded_non_variants_in_main_dir_downloaded(self): assert all(variant not in f for f in model_filenames) def 
test_sharded_non_variants_downloaded(self): - use_safetensors = True + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json", @@ -372,7 +372,7 @@ def test_sharded_non_variants_downloaded(self): assert all(variant not in f for f in model_filenames) def test_sharded_variants_downloaded(self): - use_safetensors = True + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json", @@ -387,9 +387,10 @@ def test_sharded_variants_downloaded(self): filenames, variant=variant, ignore_patterns=ignore_patterns ) assert all(variant in f for f in model_filenames) + assert model_filenames == variant_filenames def test_single_variant_with_sharded_non_variant_downloaded(self): - use_safetensors = True + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ "unet/diffusion_pytorch_model.safetensors.index.json", @@ -404,7 +405,7 @@ def test_single_variant_with_sharded_non_variant_downloaded(self): assert all(variant in f for f in model_filenames) def test_mixed_single_variant_with_sharded_non_variant_downloaded(self): - use_safetensors = True + ignore_patterns = ["*.bin"] variant = "fp16" allowed_non_variant = "unet" filenames = [ @@ -424,7 +425,7 @@ def test_mixed_single_variant_with_sharded_non_variant_downloaded(self): assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) def test_sharded_mixed_variants_downloaded(self): - use_safetensors = True + ignore_patterns = ["*.bin"] variant = "fp16" allowed_non_variant = "unet" filenames = [ @@ -445,56 +446,30 @@ def test_sharded_mixed_variants_downloaded(self): ) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) - def test_variant_ignored_if_use_safetensors(self): - use_safetensors = True - variant = "fp16" - filenames = [ - f"vae/diffusion_pytorch_model.{variant}.bin", - f"text_encoder/model.{variant}.bin", - f"unet/diffusion_pytorch_model.{variant}.bin", - "vae/diffusion_pytorch_model.safetensors", - "text_encoder/model.safetensors", - "unet/diffusion_pytorch_model.safetensors", - ] - model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, ignore_patterns=ignore_patterns - ) - assert all(variant not in f for f in model_filenames) - def test_downloading_when_no_variant_exists(self): - use_safetensors = True + ignore_patterns = ["*.bin"] variant = "fp16" filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"] - model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, ignore_patterns=ignore_patterns - ) - assert len(model_filenames) != 0 - - def test_downloading_use_safetensors_no_variant_exists(self): - use_safetensors = True - variant = "fp16" - filenames = ["text_encoder/model.bin", "unet/diffusion_pytorch_model.bin"] - model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, ignore_patterns=ignore_patterns - ) - assert all(variant not in f for f in model_filenames) + with self.assertRaisesRegex(ValueError, "but no such modeling files are available. 
"): + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) def test_downloading_use_safetensors_false(self): - use_safetensors = False - variant = "fp16" + ignore_patterns = ["*.safetensors"] filenames = [ "text_encoder/model.bin", "unet/diffusion_pytorch_model.bin", "unet/diffusion_pytorch_model.safetensors", ] model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, ignore_patterns=ignore_patterns + filenames, variant=None, ignore_patterns=ignore_patterns ) assert all(".safetensors" not in f for f in model_filenames) def test_non_variant_in_main_dir_with_variant_in_subfolder(self): - use_safetensors = True + ignore_patterns = ["*.bin"] variant = "fp16" allowed_non_variant = "diffusion_pytorch_model.safetensors" filenames = [ @@ -506,8 +481,8 @@ def test_non_variant_in_main_dir_with_variant_in_subfolder(self): ) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) - def test_download_variants_when_component_has_no_variant(self): - use_safetensors = True + def test_download_variants_when_component_has_no_safetensors_variant(self): + ignore_patterns = None variant = "fp16" filenames = [ f"unet/diffusion_pytorch_model.{variant}.bin", @@ -522,8 +497,8 @@ def test_download_variants_when_component_has_no_variant(self): f"vae/diffusion_pytorch_model.{variant}.safetensors", } == model_filenames - def test_download_sharded_variants_when_component_has_no_safetensors_variant(self): - use_safetensors = True + def test_error_when_download_sharded_variants_when_component_has_no_safetensors_variant(self): + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ f"vae/diffusion_pytorch_model.bin.index.{variant}.json", @@ -538,13 +513,13 @@ def test_download_sharded_variants_when_component_has_no_safetensors_variant(sel "unet/diffusion_pytorch_model-00003-of-00003.safetensors", f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin", ] - model_filenames, variant_filenames = variant_compatible_siblings( - filenames, variant=variant, ignore_patterns=ignore_patterns - ) - assert all(variant not in f for f in model_filenames) + with self.assertRaisesRegex(ValueError, "but no such modeling files are available. 
"): + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) def test_download_sharded_variants_when_component_has_no_safetensors_variant_and_safetensors_false(self): - use_safetensors = False + ignore_patterns = ["*.safetensors"] allowed_non_variant = "unet" variant = "fp16" filenames = [ From abba8e0ff813b75f7360e7e0b565d446ff5b79d6 Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 20 Feb 2025 22:33:37 +0530 Subject: [PATCH 11/16] update --- src/diffusers/pipelines/pipeline_utils.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 70b8f42a7806..0fc9759cc5a5 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1449,26 +1449,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: filenames, variant=variant, ignore_patterns=ignore_patterns ) - safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} - safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} - if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames: - logger.warning( - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" - f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n" - f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not " - f"expected, please check your folder structure." - ) - - bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} - bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} - if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: - logger.warning( - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" - f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n" - f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check " - f"your folder structure." 
- ) - # all filenames compatible with variant will be added allow_patterns = list(model_filenames) From 6899f400d530a2b6de27eca6f0f5c93cf2e05b99 Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 20 Feb 2025 23:43:19 +0530 Subject: [PATCH 12/16] update --- .../pipelines/pipeline_loading_utils.py | 16 ++++++++++++++++ src/diffusers/pipelines/pipeline_utils.py | 5 ++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index d01469e50f8c..d70c3561d8c3 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -141,6 +141,22 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No return True +def filter_model_files(filenames): + """Filter model repo files for just files/folders that contain model weights""" + weight_names = [ + WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + FLAX_WEIGHTS_NAME, + ONNX_WEIGHTS_NAME, + ONNX_EXTERNAL_WEIGHTS_NAME, + ] + + if is_transformers_available(): + weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME] + + return [f for f in filenames if any(f.endswith(wn) for wn in weight_names)] + + def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]: weight_names = [ WEIGHTS_NAME, diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 0fc9759cc5a5..df9d76431797 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -87,6 +87,7 @@ _resolve_custom_pipeline_and_cls, _unwrap_model, _update_init_kwargs_with_connected_pipeline, + filter_model_files, load_sub_model, maybe_raise_or_warn, variant_compatible_siblings, @@ -1415,7 +1416,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n" f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." 
) - model_folder_names = {os.path.split(f)[0] for f in filenames if os.path.split(f)[0] in folder_names} # retrieve passed components that should not be downloaded pipeline_class = _get_pipeline_class( @@ -1432,6 +1432,9 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: expected_components, _ = cls._get_signature_keys(pipeline_class) passed_components = [k for k in expected_components if k in kwargs] + model_folder_names = { + os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names + } # retrieve all patterns that should not be downloaded and error out when needed ignore_patterns = _get_ignore_patterns( passed_components, From 3db5a69b9f306d1813475830c78ae95aafc08569 Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 20 Feb 2025 23:59:27 +0530 Subject: [PATCH 13/16] update --- src/diffusers/pipelines/pipeline_loading_utils.py | 4 +++- src/diffusers/pipelines/pipeline_utils.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index d70c3561d8c3..1915fa33cf9a 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -154,7 +154,9 @@ def filter_model_files(filenames): if is_transformers_available(): weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME] - return [f for f in filenames if any(f.endswith(wn) for wn in weight_names)] + allowed_extensions = [wn.split(".")[-1] for wn in weight_names] + + return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)] def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]: diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index df9d76431797..fdff756d9464 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1432,6 +1432,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: expected_components, _ = cls._get_signature_keys(pipeline_class) passed_components = [k for k in expected_components if k in kwargs] + # retrieve the names of the folders containing model weights model_folder_names = { os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names } From 02b089206b1326e1e3e9367d451a6fc036d5ca68 Mon Sep 17 00:00:00 2001 From: DN6 Date: Fri, 21 Feb 2025 12:33:42 +0530 Subject: [PATCH 14/16] update --- src/diffusers/pipelines/pipeline_utils.py | 70 ++++++++++++++++------- tests/pipelines/test_pipeline_utils.py | 22 +++++++ 2 files changed, 71 insertions(+), 21 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 3d9a630a06ea..3306a2df66d9 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -22,16 +22,21 @@ import sys from dataclasses import dataclass from pathlib import Path -from typing import (Any, Callable, Dict, List, Optional, Union, get_args, - get_origin) +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin import numpy as np import PIL.Image import requests import torch -from huggingface_hub import (DDUFEntry, ModelCard, create_repo, - hf_hub_download, model_info, read_dduf_file, - snapshot_download) +from huggingface_hub import ( + DDUFEntry, 
+ ModelCard, + create_repo, + hf_hub_download, + model_info, + read_dduf_file, + snapshot_download, +) from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args from packaging import version from requests.exceptions import HTTPError @@ -45,28 +50,51 @@ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin from ..quantizers.bitsandbytes.utils import _check_bnb_status from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from ..utils import (CONFIG_NAME, DEPRECATED_REVISION_ARGS, BaseOutput, - PushToHubMixin, is_accelerate_available, - is_accelerate_version, is_torch_npu_available, - is_torch_version, is_transformers_version, logging, - numpy_to_pil) -from ..utils.hub_utils import (_check_legacy_sharding_variant_format, - load_or_create_model_card, populate_model_card) +from ..utils import ( + CONFIG_NAME, + DEPRECATED_REVISION_ARGS, + BaseOutput, + PushToHubMixin, + is_accelerate_available, + is_accelerate_version, + is_torch_npu_available, + is_torch_version, + is_transformers_version, + logging, + numpy_to_pil, +) +from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card from ..utils.torch_utils import is_compiled_module + if is_torch_npu_available(): import torch_npu # noqa: F401 from .pipeline_loading_utils import ( - ALL_IMPORTABLE_CLASSES, CONNECTED_PIPES_KEYS, CUSTOM_PIPELINE_FILE_NAME, - LOADABLE_CLASSES, _download_dduf_file, _fetch_class_library_tuple, - _get_custom_components_and_folders, _get_custom_pipeline_class, - _get_final_device_map, _get_ignore_patterns, _get_pipeline_class, - _identify_model_variants, _maybe_raise_error_for_incorrect_transformers, - _maybe_raise_warning_for_inpainting, _resolve_custom_pipeline_and_cls, - _unwrap_model, _update_init_kwargs_with_connected_pipeline, - filter_model_files, load_sub_model, maybe_raise_or_warn, - variant_compatible_siblings, warn_deprecated_model_variant) + ALL_IMPORTABLE_CLASSES, + CONNECTED_PIPES_KEYS, + CUSTOM_PIPELINE_FILE_NAME, + LOADABLE_CLASSES, + _download_dduf_file, + _fetch_class_library_tuple, + _get_custom_components_and_folders, + _get_custom_pipeline_class, + _get_final_device_map, + _get_ignore_patterns, + _get_pipeline_class, + _identify_model_variants, + _maybe_raise_error_for_incorrect_transformers, + _maybe_raise_warning_for_inpainting, + _resolve_custom_pipeline_and_cls, + _unwrap_model, + _update_init_kwargs_with_connected_pipeline, + filter_model_files, + load_sub_model, + maybe_raise_or_warn, + variant_compatible_siblings, + warn_deprecated_model_variant, +) + if is_accelerate_available(): import accelerate diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 2e222f14cbaf..71bea3184139 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -540,6 +540,28 @@ def test_download_sharded_variants_when_component_has_no_safetensors_variant_and ) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + def test_download_onnx_models(self): + ignore_patterns = ["*.safetensors"] + filenames = [ + "vae/model.onnx", + "unet/model.onnx", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=None, ignore_patterns=ignore_patterns + ) + assert model_filenames == set(filenames) + + def test_download_flax_models(self): + ignore_patterns = ["*.safetensors", "*.bin"] + filenames = [ + "vae/diffusion_flax_model.msgpack", + 
"unet/diffusion_flax_model.msgpack", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=None, ignore_patterns=ignore_patterns + ) + assert model_filenames == set(filenames) + class ProgressBarTests(unittest.TestCase): def get_dummy_components_image_generation(self): From f56880506fc940c47cd9c498ef0647d15ec5d8af Mon Sep 17 00:00:00 2001 From: DN6 Date: Tue, 4 Mar 2025 06:05:14 +0530 Subject: [PATCH 15/16] update --- src/diffusers/pipelines/pipeline_loading_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index a08a7d7bbeff..07da8b5e2e2e 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -235,7 +235,7 @@ def filter_with_regex(filenames, pattern_re): component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re) component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re) - if component_variants: + if component_variants or component_legacy_variants: variant_filenames.update( component_variants | component_variant_index_files if component_variants From f35f83b9cdd55e8e37f89ec0a26268f2b8c748b8 Mon Sep 17 00:00:00 2001 From: DN6 Date: Tue, 4 Mar 2025 06:26:15 +0530 Subject: [PATCH 16/16] update --- tests/pipelines/test_pipeline_utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 71bea3184139..964b55fde651 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -540,6 +540,23 @@ def test_download_sharded_variants_when_component_has_no_safetensors_variant_and ) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + def test_download_sharded_legacy_variants(self): + ignore_patterns = None + variant = "fp16" + filenames = [ + f"vae/transformer/diffusion_pytorch_model.safetensors.{variant}.index.json", + "vae/diffusion_pytorch_model.safetensors.index.json", + f"vae/diffusion_pytorch_model-00002-of-00002.{variant}.safetensors", + "vae/diffusion_pytorch_model-00001-of-00003.safetensors", + "vae/diffusion_pytorch_model-00002-of-00003.safetensors", + "vae/diffusion_pytorch_model-00003-of-00003.safetensors", + f"vae/diffusion_pytorch_model-00001-of-00002.{variant}.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) + assert all(variant in f for f in model_filenames) + def test_download_onnx_models(self): ignore_patterns = ["*.safetensors"] filenames = [