diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 9a9afa198b4c..07da8b5e2e2e 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -104,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No extension is replaced with ".safetensors" """ passed_components = passed_components or [] - if folder_names is not None: + if folder_names: filenames = {f for f in filenames if os.path.split(f)[0] in folder_names} # extract all components of the pipeline and their associated files @@ -141,7 +141,25 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No return True -def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]: +def filter_model_files(filenames): + """Filter model repo files for just files/folders that contain model weights""" + weight_names = [ + WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + FLAX_WEIGHTS_NAME, + ONNX_WEIGHTS_NAME, + ONNX_EXTERNAL_WEIGHTS_NAME, + ] + + if is_transformers_available(): + weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME] + + allowed_extensions = [wn.split(".")[-1] for wn in weight_names] + + return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)] + + +def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]: weight_names = [ WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, @@ -169,6 +187,10 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi variant_index_re = re.compile( rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" ) + legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$") + legacy_variant_index_re = re.compile( + rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.{variant}\.index\.json$" + ) # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors` non_variant_file_re = re.compile( @@ -177,54 +199,68 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi # `text_encoder/pytorch_model.bin.index.json` non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json") - if variant is not None: - variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None} - variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None} - variant_filenames = variant_weights | variant_indexes - else: - variant_filenames = set() + def filter_for_compatible_extensions(filenames, ignore_patterns=None): + if not ignore_patterns: + return filenames + + # ignore patterns uses glob style patterns e.g *.safetensors but we're only + # interested in the extension name + return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)} + + def filter_with_regex(filenames, pattern_re): + return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None} + + # Group files by component + components = {} + for filename in filenames: + if not len(filename.split("/")) == 2: + components.setdefault("", []).append(filename) + continue - non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None} - non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None} - non_variant_filenames = non_variant_weights | non_variant_indexes + component, _ = filename.split("/") + components.setdefault(component, []).append(filename) - # all variant filenames will be used by default - usable_filenames = set(variant_filenames) + usable_filenames = set() + variant_filenames = set() + for component, component_filenames in components.items(): + component_filenames = filter_for_compatible_extensions(component_filenames, ignore_patterns=ignore_patterns) + + component_variants = set() + component_legacy_variants = set() + component_non_variants = set() + if variant is not None: + component_variants = filter_with_regex(component_filenames, variant_file_re) + component_variant_index_files = filter_with_regex(component_filenames, variant_index_re) + + component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re) + component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re) + + if component_variants or component_legacy_variants: + variant_filenames.update( + component_variants | component_variant_index_files + if component_variants + else component_legacy_variants | component_legacy_variant_index_files + ) - def convert_to_variant(filename): - if "index" in filename: - variant_filename = filename.replace("index", f"index.{variant}") - elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None: - variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}" else: - variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" - return variant_filename + component_non_variants = filter_with_regex(component_filenames, non_variant_file_re) + component_variant_index_files = filter_with_regex(component_filenames, non_variant_index_re) - def find_component(filename): - if not len(filename.split("/")) == 2: - return - component = filename.split("/")[0] - return component - - def has_sharded_variant(component, variant, variant_filenames): - # If component exists check for sharded variant index filename - # If component doesn't exist check main dir for sharded variant index filename - component = component + "/" if component else "" - variant_index_re = re.compile( - rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" - ) - return any(f for f in variant_filenames if variant_index_re.match(f) is not None) + usable_filenames.update(component_non_variants | component_variant_index_files) - for filename in non_variant_filenames: - if convert_to_variant(filename) in variant_filenames: - continue + usable_filenames.update(variant_filenames) - component = find_component(filename) - # If a sharded variant exists skip adding to allowed patterns - if has_sharded_variant(component, variant, variant_filenames): - continue + if len(variant_filenames) == 0 and variant is not None: + error_message = f"You are trying to load model files of the `variant={variant}`, but no such modeling files are available. " + raise ValueError(error_message) - usable_filenames.add(filename) + if len(variant_filenames) > 0 and usable_filenames != variant_filenames: + logger.warning( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" + f"[{', '.join(variant_filenames)}]\nLoaded non-{variant} filenames:\n" + f"[{', '.join(usable_filenames - variant_filenames)}\nIf this behavior is not " + f"expected, please check your folder structure." + ) return usable_filenames, variant_filenames @@ -922,10 +958,6 @@ def _get_custom_components_and_folders( f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." ) - if len(variant_filenames) == 0 and variant is not None: - error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." - raise ValueError(error_message) - return custom_components, folder_names @@ -933,7 +965,6 @@ def _get_ignore_patterns( passed_components, model_folder_names: List[str], model_filenames: List[str], - variant_filenames: List[str], use_safetensors: bool, from_flax: bool, allow_pickle: bool, @@ -964,16 +995,6 @@ def _get_ignore_patterns( if not use_onnx: ignore_patterns += ["*.onnx", "*.pb"] - safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} - safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} - if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames: - logger.warning( - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" - f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n" - f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not " - f"expected, please check your folder structure." - ) - else: ignore_patterns = ["*.safetensors", "*.msgpack"] @@ -981,16 +1002,6 @@ def _get_ignore_patterns( if not use_onnx: ignore_patterns += ["*.onnx", "*.pb"] - bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} - bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} - if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: - logger.warning( - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" - f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n" - f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check " - f"your folder structure." - ) - return ignore_patterns diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 1b306b1805d8..cb60350be1b0 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -89,6 +89,7 @@ _resolve_custom_pipeline_and_cls, _unwrap_model, _update_init_kwargs_with_connected_pipeline, + filter_model_files, load_sub_model, maybe_raise_or_warn, variant_compatible_siblings, @@ -1387,10 +1388,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: revision=revision, ) - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True + allow_pickle = True if (use_safetensors is None or use_safetensors is False) else False + use_safetensors = use_safetensors if use_safetensors is not None else True allow_patterns = None ignore_patterns = None @@ -1405,6 +1404,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: model_info_call_error = e # save error to reraise it if model is not cached locally if not local_files_only: + config_file = hf_hub_download( + pretrained_model_name, + cls.config_name, + cache_dir=cache_dir, + revision=revision, + proxies=proxies, + force_download=force_download, + token=token, + ) + config_dict = cls._dict_from_json_file(config_file) + ignore_filenames = config_dict.pop("_ignore_files", []) + filenames = {sibling.rfilename for sibling in info.siblings} if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant): warn_msg = ( @@ -1419,61 +1430,20 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: ) logger.warning(warn_msg) - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) - - config_file = hf_hub_download( - pretrained_model_name, - cls.config_name, - cache_dir=cache_dir, - revision=revision, - proxies=proxies, - force_download=force_download, - token=token, - ) - - config_dict = cls._dict_from_json_file(config_file) - ignore_filenames = config_dict.pop("_ignore_files", []) - - # remove ignored filenames - model_filenames = set(model_filenames) - set(ignore_filenames) - variant_filenames = set(variant_filenames) - set(ignore_filenames) - + filenames = set(filenames) - set(ignore_filenames) if revision in DEPRECATED_REVISION_ARGS and version.parse( version.parse(__version__).base_version ) >= version.parse("0.22.0"): - warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames) + warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, filenames) custom_components, folder_names = _get_custom_components_and_folders( - pretrained_model_name, config_dict, filenames, variant_filenames, variant + pretrained_model_name, config_dict, filenames, variant ) - model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names} - custom_class_name = None if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)): custom_pipeline = config_dict["_class_name"][0] custom_class_name = config_dict["_class_name"][1] - # all filenames compatible with variant will be added - allow_patterns = list(model_filenames) - - # allow all patterns from non-model folders - # this enables downloading schedulers, tokenizers, ... - allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names] - # add custom component files - allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()] - # add custom pipeline file - allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else [] - # also allow downloading config.json files with the model - allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] - # also allow downloading generation_config.json of the transformers model - allow_patterns += [os.path.join(k, "generation_config.json") for k in model_folder_names] - allow_patterns += [ - SCHEDULER_CONFIG_NAME, - CONFIG_NAME, - cls.config_name, - CUSTOM_PIPELINE_FILE_NAME, - ] - load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames load_components_from_hub = len(custom_components) > 0 @@ -1506,12 +1476,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: expected_components, _ = cls._get_signature_keys(pipeline_class) passed_components = [k for k in expected_components if k in kwargs] + # retrieve the names of the folders containing model weights + model_folder_names = { + os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names + } # retrieve all patterns that should not be downloaded and error out when needed ignore_patterns = _get_ignore_patterns( passed_components, model_folder_names, - model_filenames, - variant_filenames, + filenames, use_safetensors, from_flax, allow_pickle, @@ -1520,6 +1493,29 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: variant, ) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) + + # all filenames compatible with variant will be added + allow_patterns = list(model_filenames) + + # allow all patterns from non-model folders + # this enables downloading schedulers, tokenizers, ... + allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names] + # add custom component files + allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()] + # add custom pipeline file + allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else [] + # also allow downloading config.json files with the model + allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] + allow_patterns += [ + SCHEDULER_CONFIG_NAME, + CONFIG_NAME, + cls.config_name, + CUSTOM_PIPELINE_FILE_NAME, + ] + # Don't download any objects that are passed allow_patterns = [ p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components) diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index acf7d9d8401b..964b55fde651 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -212,6 +212,7 @@ def test_diffusers_is_compatible_no_components_only_variants(self): class VariantCompatibleSiblingsTest(unittest.TestCase): def test_only_non_variants_downloaded(self): + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ f"vae/diffusion_pytorch_model.{variant}.safetensors", @@ -222,10 +223,13 @@ def test_only_non_variants_downloaded(self): "unet/diffusion_pytorch_model.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=None, ignore_patterns=ignore_patterns + ) assert all(variant not in f for f in model_filenames) def test_only_variants_downloaded(self): + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ f"vae/diffusion_pytorch_model.{variant}.safetensors", @@ -236,10 +240,13 @@ def test_only_variants_downloaded(self): "unet/diffusion_pytorch_model.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) assert all(variant in f for f in model_filenames) def test_mixed_variants_downloaded(self): + ignore_patterns = ["*.bin"] variant = "fp16" non_variant_file = "text_encoder/model.safetensors" filenames = [ @@ -249,23 +256,27 @@ def test_mixed_variants_downloaded(self): f"unet/diffusion_pytorch_model.{variant}.safetensors", "unet/diffusion_pytorch_model.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) def test_non_variants_in_main_dir_downloaded(self): + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ f"diffusion_pytorch_model.{variant}.safetensors", "diffusion_pytorch_model.safetensors", "model.safetensors", f"model.{variant}.safetensors", - f"diffusion_pytorch_model.{variant}.safetensors", - "diffusion_pytorch_model.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=None, ignore_patterns=ignore_patterns + ) assert all(variant not in f for f in model_filenames) def test_variants_in_main_dir_downloaded(self): + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ f"diffusion_pytorch_model.{variant}.safetensors", @@ -275,23 +286,76 @@ def test_variants_in_main_dir_downloaded(self): f"diffusion_pytorch_model.{variant}.safetensors", "diffusion_pytorch_model.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) assert all(variant in f for f in model_filenames) def test_mixed_variants_in_main_dir_downloaded(self): + ignore_patterns = ["*.bin"] variant = "fp16" non_variant_file = "model.safetensors" filenames = [ f"diffusion_pytorch_model.{variant}.safetensors", "diffusion_pytorch_model.safetensors", "model.safetensors", - f"diffusion_pytorch_model.{variant}.safetensors", - "diffusion_pytorch_model.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) + def test_sharded_variants_in_main_dir_downloaded(self): + ignore_patterns = ["*.bin"] + variant = "fp16" + filenames = [ + "diffusion_pytorch_model.safetensors.index.json", + "diffusion_pytorch_model-00001-of-00003.safetensors", + "diffusion_pytorch_model-00002-of-00003.safetensors", + "diffusion_pytorch_model-00003-of-00003.safetensors", + f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", + f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", + f"diffusion_pytorch_model.safetensors.index.{variant}.json", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) + assert all(variant in f for f in model_filenames) + + def test_mixed_sharded_and_variant_in_main_dir_downloaded(self): + ignore_patterns = ["*.bin"] + variant = "fp16" + filenames = [ + "diffusion_pytorch_model.safetensors.index.json", + "diffusion_pytorch_model-00001-of-00003.safetensors", + "diffusion_pytorch_model-00002-of-00003.safetensors", + "diffusion_pytorch_model-00003-of-00003.safetensors", + f"diffusion_pytorch_model.{variant}.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) + assert all(variant in f for f in model_filenames) + + def test_mixed_sharded_non_variants_in_main_dir_downloaded(self): + ignore_patterns = ["*.bin"] + variant = "fp16" + filenames = [ + f"diffusion_pytorch_model.safetensors.index.{variant}.json", + "diffusion_pytorch_model.safetensors.index.json", + "diffusion_pytorch_model-00001-of-00003.safetensors", + "diffusion_pytorch_model-00002-of-00003.safetensors", + "diffusion_pytorch_model-00003-of-00003.safetensors", + f"diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", + f"diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=None, ignore_patterns=ignore_patterns + ) + assert all(variant not in f for f in model_filenames) + def test_sharded_non_variants_downloaded(self): + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json", @@ -302,10 +366,13 @@ def test_sharded_non_variants_downloaded(self): f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=None, ignore_patterns=ignore_patterns + ) assert all(variant not in f for f in model_filenames) def test_sharded_variants_downloaded(self): + ignore_patterns = ["*.bin"] variant = "fp16" filenames = [ f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json", @@ -316,10 +383,49 @@ def test_sharded_variants_downloaded(self): f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) + assert all(variant in f for f in model_filenames) + assert model_filenames == variant_filenames + + def test_single_variant_with_sharded_non_variant_downloaded(self): + ignore_patterns = ["*.bin"] + variant = "fp16" + filenames = [ + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"unet/diffusion_pytorch_model.{variant}.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) assert all(variant in f for f in model_filenames) + def test_mixed_single_variant_with_sharded_non_variant_downloaded(self): + ignore_patterns = ["*.bin"] + variant = "fp16" + allowed_non_variant = "unet" + filenames = [ + "vae/diffusion_pytorch_model.safetensors.index.json", + "vae/diffusion_pytorch_model-00001-of-00003.safetensors", + "vae/diffusion_pytorch_model-00002-of-00003.safetensors", + "vae/diffusion_pytorch_model-00003-of-00003.safetensors", + f"vae/diffusion_pytorch_model.{variant}.safetensors", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) + assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + def test_sharded_mixed_variants_downloaded(self): + ignore_patterns = ["*.bin"] variant = "fp16" allowed_non_variant = "unet" filenames = [ @@ -335,9 +441,144 @@ def test_sharded_mixed_variants_downloaded(self): "vae/diffusion_pytorch_model-00002-of-00003.safetensors", "vae/diffusion_pytorch_model-00003-of-00003.safetensors", ] - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + def test_downloading_when_no_variant_exists(self): + ignore_patterns = ["*.bin"] + variant = "fp16" + filenames = ["model.safetensors", "diffusion_pytorch_model.safetensors"] + with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "): + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) + + def test_downloading_use_safetensors_false(self): + ignore_patterns = ["*.safetensors"] + filenames = [ + "text_encoder/model.bin", + "unet/diffusion_pytorch_model.bin", + "unet/diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=None, ignore_patterns=ignore_patterns + ) + + assert all(".safetensors" not in f for f in model_filenames) + + def test_non_variant_in_main_dir_with_variant_in_subfolder(self): + ignore_patterns = ["*.bin"] + variant = "fp16" + allowed_non_variant = "diffusion_pytorch_model.safetensors" + filenames = [ + f"unet/diffusion_pytorch_model.{variant}.safetensors", + "diffusion_pytorch_model.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) + assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + + def test_download_variants_when_component_has_no_safetensors_variant(self): + ignore_patterns = None + variant = "fp16" + filenames = [ + f"unet/diffusion_pytorch_model.{variant}.bin", + "vae/diffusion_pytorch_model.safetensors", + f"vae/diffusion_pytorch_model.{variant}.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) + assert { + f"unet/diffusion_pytorch_model.{variant}.bin", + f"vae/diffusion_pytorch_model.{variant}.safetensors", + } == model_filenames + + def test_error_when_download_sharded_variants_when_component_has_no_safetensors_variant(self): + ignore_patterns = ["*.bin"] + variant = "fp16" + filenames = [ + f"vae/diffusion_pytorch_model.bin.index.{variant}.json", + "vae/diffusion_pytorch_model.safetensors.index.json", + f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin", + "vae/diffusion_pytorch_model-00001-of-00003.safetensors", + "vae/diffusion_pytorch_model-00002-of-00003.safetensors", + "vae/diffusion_pytorch_model-00003-of-00003.safetensors", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin", + ] + with self.assertRaisesRegex(ValueError, "but no such modeling files are available. "): + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) + + def test_download_sharded_variants_when_component_has_no_safetensors_variant_and_safetensors_false(self): + ignore_patterns = ["*.safetensors"] + allowed_non_variant = "unet" + variant = "fp16" + filenames = [ + f"vae/diffusion_pytorch_model.bin.index.{variant}.json", + "vae/diffusion_pytorch_model.safetensors.index.json", + f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.bin", + "vae/diffusion_pytorch_model-00001-of-00003.safetensors", + "vae/diffusion_pytorch_model-00002-of-00003.safetensors", + "vae/diffusion_pytorch_model-00003-of-00003.safetensors", + "unet/diffusion_pytorch_model.safetensors.index.json", + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", + f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.bin", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) + assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) + + def test_download_sharded_legacy_variants(self): + ignore_patterns = None + variant = "fp16" + filenames = [ + f"vae/transformer/diffusion_pytorch_model.safetensors.{variant}.index.json", + "vae/diffusion_pytorch_model.safetensors.index.json", + f"vae/diffusion_pytorch_model-00002-of-00002.{variant}.safetensors", + "vae/diffusion_pytorch_model-00001-of-00003.safetensors", + "vae/diffusion_pytorch_model-00002-of-00003.safetensors", + "vae/diffusion_pytorch_model-00003-of-00003.safetensors", + f"vae/diffusion_pytorch_model-00001-of-00002.{variant}.safetensors", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=variant, ignore_patterns=ignore_patterns + ) + assert all(variant in f for f in model_filenames) + + def test_download_onnx_models(self): + ignore_patterns = ["*.safetensors"] + filenames = [ + "vae/model.onnx", + "unet/model.onnx", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=None, ignore_patterns=ignore_patterns + ) + assert model_filenames == set(filenames) + + def test_download_flax_models(self): + ignore_patterns = ["*.safetensors", "*.bin"] + filenames = [ + "vae/diffusion_flax_model.msgpack", + "unet/diffusion_flax_model.msgpack", + ] + model_filenames, variant_filenames = variant_compatible_siblings( + filenames, variant=None, ignore_patterns=ignore_patterns + ) + assert model_filenames == set(filenames) + class ProgressBarTests(unittest.TestCase): def get_dummy_components_image_generation(self):