Fix for fetching variants only #10646

Merged: 19 commits, Mar 10, 2025
145 changes: 78 additions & 67 deletions src/diffusers/pipelines/pipeline_loading_utils.py
@@ -104,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=None):
        extension is replaced with ".safetensors"
    """
    passed_components = passed_components or []
-    if folder_names is not None:
+    if folder_names:
        filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}

    # extract all components of the pipeline and their associated files
@@ -141,7 +141,25 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=None):
    return True


-def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
+def filter_model_files(filenames):
+    """Filter model repo files for just files/folders that contain model weights"""
+    weight_names = [
+        WEIGHTS_NAME,
+        SAFETENSORS_WEIGHTS_NAME,
+        FLAX_WEIGHTS_NAME,
+        ONNX_WEIGHTS_NAME,
+        ONNX_EXTERNAL_WEIGHTS_NAME,
+    ]
+
+    if is_transformers_available():
+        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+
+    allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
+
+    return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
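As a quick sanity check, the new helper reduces to a plain extension filter. A minimal sketch of its behavior, assuming the usual diffusers weight-name constants (diffusion_pytorch_model.bin, diffusion_pytorch_model.safetensors, diffusion_flax_model.msgpack, model.onnx, weights.pb, plus the transformers equivalents):

    # Hypothetical repo listing, for illustration only.
    filenames = [
        "model_index.json",
        "unet/diffusion_pytorch_model.safetensors",
        "unet/config.json",
        "vae/diffusion_pytorch_model.bin",
        "scheduler/scheduler_config.json",
    ]
    # The extensions derived from the constants above reduce to these five.
    allowed_extensions = ["bin", "safetensors", "msgpack", "onnx", "pb"]
    print([f for f in filenames if any(f.endswith(ext) for ext in allowed_extensions)])
    # -> ['unet/diffusion_pytorch_model.safetensors', 'vae/diffusion_pytorch_model.bin']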


+def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
    weight_names = [
        WEIGHTS_NAME,
        SAFETENSORS_WEIGHTS_NAME,
@@ -169,6 +187,10 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
    variant_index_re = re.compile(
        rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
    )
+    legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
+    legacy_variant_index_re = re.compile(
+        rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.{variant}\.index\.json$"
+    )
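For orientation, with variant="fp16" the patterns separate the modern and legacy naming schemes roughly as follows (illustrative filenames; variant_file_re itself is defined in the collapsed lines above):

    # variant_file_re         -> "diffusion_pytorch_model.fp16.safetensors",
    #                            "model.fp16-00001-of-00002.safetensors"
    # variant_index_re        -> "diffusion_pytorch_model.safetensors.index.fp16.json"
    # legacy_variant_file_re  -> "model-00001-of-00002.fp16.safetensors"  (variant after the shard id)
    # legacy_variant_index_re -> "diffusion_pytorch_model.safetensors.fp16.index.json"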

    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
    non_variant_file_re = re.compile(
@@ -177,54 +199,68 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
    # `text_encoder/pytorch_model.bin.index.json`
    non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")

-    if variant is not None:
-        variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
-        variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
-        variant_filenames = variant_weights | variant_indexes
-    else:
-        variant_filenames = set()
+    def filter_for_compatible_extensions(filenames, ignore_patterns=None):
+        if not ignore_patterns:
+            return filenames
+
+        # ignore patterns uses glob style patterns e.g. *.safetensors but we're only
+        # interested in the extension name
+        return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
+
+    def filter_with_regex(filenames, pattern_re):
+        return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
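One subtlety worth noting: str.lstrip("*.") strips any leading run of the characters * and . rather than the literal prefix "*.", which happens to be exactly right for glob-style patterns such as *.safetensors. A minimal check:

    assert "*.safetensors".lstrip("*.") == "safetensors"
    assert "*.onnx".lstrip("*.") == "onnx"
    # so the ignore test reduces to a plain suffix comparison:
    assert "unet/diffusion_pytorch_model.safetensors".endswith("safetensors")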

+    # Group files by component
+    components = {}
+    for filename in filenames:
+        if not len(filename.split("/")) == 2:
+            components.setdefault("", []).append(filename)
+            continue
+
-    non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
-    non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
-    non_variant_filenames = non_variant_weights | non_variant_indexes
+        component, _ = filename.split("/")
+        components.setdefault(component, []).append(filename)
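With a typical two-level repo layout the grouping comes out as below (hypothetical filenames); root-level files such as model_index.json land under the empty-string key:

    # filenames = ["model_index.json",
    #              "unet/diffusion_pytorch_model.fp16.safetensors",
    #              "vae/diffusion_pytorch_model.safetensors"]
    # components == {"":     ["model_index.json"],
    #                "unet": ["unet/diffusion_pytorch_model.fp16.safetensors"],
    #                "vae":  ["vae/diffusion_pytorch_model.safetensors"]}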

-    # all variant filenames will be used by default
-    usable_filenames = set(variant_filenames)
+    usable_filenames = set()
+    variant_filenames = set()
+    for component, component_filenames in components.items():
+        component_filenames = filter_for_compatible_extensions(component_filenames, ignore_patterns=ignore_patterns)
+
+        component_variants = set()
+        component_legacy_variants = set()
+        component_non_variants = set()
+        if variant is not None:
+            component_variants = filter_with_regex(component_filenames, variant_file_re)
+            component_variant_index_files = filter_with_regex(component_filenames, variant_index_re)
+
+            component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re)
+            component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re)
+
+        if component_variants or component_legacy_variants:
+            variant_filenames.update(
+                component_variants | component_variant_index_files
+                if component_variants
[Collaborator]: so this `if` is redundant here; how do we intend to handle the legacy variant?

[Collaborator, author]: Oh, missed this. Great catch! Updated to account for this.

+                else component_legacy_variants | component_legacy_variant_index_files
+            )
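In other words, per component the modern variant naming takes priority and the legacy sets are only used as a fallback; sketched for variant="fp16" (hypothetical sets):

    # component_variants        == {"unet/diffusion_pytorch_model.fp16.safetensors"}
    # component_legacy_variants == {"unet/model-00001-of-00002.fp16.safetensors"}
    # -> variant_filenames picks up only the modern set plus its index files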

-    def convert_to_variant(filename):
-        if "index" in filename:
-            variant_filename = filename.replace("index", f"index.{variant}")
-        elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
-            variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
-        else:
-            variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
-        return variant_filename
+        else:
+            component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)
[Collaborator]: awesome!

+            component_variant_index_files = filter_with_regex(component_filenames, non_variant_index_re)

-    def find_component(filename):
-        if not len(filename.split("/")) == 2:
-            return
-        component = filename.split("/")[0]
-        return component
-
-    def has_sharded_variant(component, variant, variant_filenames):
-        # If component exists check for sharded variant index filename
-        # If component doesn't exist check main dir for sharded variant index filename
-        component = component + "/" if component else ""
-        variant_index_re = re.compile(
-            rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
-        )
-        return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
+            usable_filenames.update(component_non_variants | component_variant_index_files)

-    for filename in non_variant_filenames:
-        if convert_to_variant(filename) in variant_filenames:
-            continue
+    usable_filenames.update(variant_filenames)

-        component = find_component(filename)
-        # If a sharded variant exists skip adding to allowed patterns
-        if has_sharded_variant(component, variant, variant_filenames):
-            continue
+    if len(variant_filenames) == 0 and variant is not None:
+        error_message = f"You are trying to load model files of the `variant={variant}`, but no such modeling files are available. "
+        raise ValueError(error_message)
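This is the consolidated error referenced in the review comments further down. From the caller's side a failing load would look roughly like (hypothetical repo name):

    # DiffusionPipeline.download("some-org/some-pipeline", variant="fp16")
    # ValueError: You are trying to load model files of the `variant=fp16`,
    # but no such modeling files are available.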

-        usable_filenames.add(filename)
+    if len(variant_filenames) > 0 and usable_filenames != variant_filenames:
+        logger.warning(
+            f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
+            f"[{', '.join(variant_filenames)}]\nLoaded non-{variant} filenames:\n"
+            f"[{', '.join(usable_filenames - variant_filenames)}]\nIf this behavior is not "
+            f"expected, please check your folder structure."
+        )
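Concretely, if only the unet ships an fp16 variant, the function returns a mix and warns once (hypothetical sets):

    # variant_filenames == {"unet/diffusion_pytorch_model.fp16.safetensors"}
    # usable_filenames  == variant_filenames | {"vae/diffusion_pytorch_model.safetensors"}
    # -> the warning above fires and lists the vae file as a non-fp16 filename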

    return usable_filenames, variant_filenames

@@ -922,18 +958,13 @@ def _get_custom_components_and_folders(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)

if len(variant_filenames) == 0 and variant is not None:
[Collaborator, author]: We don't need this in this function. The variant_filenames are only being used to raise this error. It's better to consolidate it into variant_compatible_siblings.
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)

    return custom_components, folder_names


def _get_ignore_patterns(
    passed_components,
    model_folder_names: List[str],
    model_filenames: List[str],
-    variant_filenames: List[str],
[Collaborator, author]: variant_filenames is only being used to raise a warning here. We can consolidate this into variant_compatible_siblings and raise the warning there.
    use_safetensors: bool,
    from_flax: bool,
    allow_pickle: bool,
@@ -964,33 +995,13 @@ def _get_ignore_patterns(
        if not use_onnx:
            ignore_patterns += ["*.onnx", "*.pb"]

-        safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
-        safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
-        if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
-            logger.warning(
-                f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
-                f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
-                f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
-                f"expected, please check your folder structure."
-            )

    else:
        ignore_patterns = ["*.safetensors", "*.msgpack"]

        use_onnx = use_onnx if use_onnx is not None else is_onnx
        if not use_onnx:
            ignore_patterns += ["*.onnx", "*.pb"]

-        bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
-        bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
-        if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
-            logger.warning(
-                f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
-                f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
-                f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
-                f"your folder structure."
-            )

    return ignore_patterns
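With the warnings gone the function only assembles glob lists. A sketch of the outputs, assuming the safetensors branch above mirrors this else branch with *.bin in place of *.safetensors:

    # use_safetensors=True (safetensors-compatible repo), use_onnx=False:
    #     ignore_patterns == ["*.bin", "*.msgpack", "*.onnx", "*.pb"]
    # use_safetensors=False, use_onnx=False:
    #     ignore_patterns == ["*.safetensors", "*.msgpack", "*.onnx", "*.pb"]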


96 changes: 46 additions & 50 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -89,6 +89,7 @@
    _resolve_custom_pipeline_and_cls,
    _unwrap_model,
    _update_init_kwargs_with_connected_pipeline,
+    filter_model_files,
    load_sub_model,
    maybe_raise_or_warn,
    variant_compatible_siblings,
@@ -1387,10 +1388,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
            revision=revision,
        )

-        allow_pickle = False
-        if use_safetensors is None:
-            use_safetensors = True
-            allow_pickle = True
+        allow_pickle = True if (use_safetensors is None or use_safetensors is False) else False
+        use_safetensors = use_safetensors if use_safetensors is not None else True
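The rewritten pair is equivalent to the old block except for the explicit use_safetensors=False case, which now also permits pickled weights. As a truth table:

    # use_safetensors=None  -> use_safetensors=True,  allow_pickle=True   (prefer safetensors, allow fallback)
    # use_safetensors=True  -> use_safetensors=True,  allow_pickle=False  (strict)
    # use_safetensors=False -> use_safetensors=False, allow_pickle=True   (previously allow_pickle stayed False)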

        allow_patterns = None
        ignore_patterns = None
@@ -1405,6 +1404,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
            model_info_call_error = e  # save error to reraise it if model is not cached locally

        if not local_files_only:
+            config_file = hf_hub_download(
+                pretrained_model_name,
+                cls.config_name,
+                cache_dir=cache_dir,
+                revision=revision,
+                proxies=proxies,
+                force_download=force_download,
+                token=token,
+            )
+            config_dict = cls._dict_from_json_file(config_file)
+            ignore_filenames = config_dict.pop("_ignore_files", [])
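The config is now fetched before any file filtering so that _ignore_files can be applied to the raw repo listing. A hypothetical model_index.json carrying that key:

    # {
    #     "_class_name": "StableDiffusionPipeline",
    #     "_ignore_files": ["vae/diffusion_pytorch_model.bin"],
    #     ...
    # }
    # The popped list is subtracted from `filenames` a few lines below.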

            filenames = {sibling.rfilename for sibling in info.siblings}
            if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
                warn_msg = (
@@ -1419,61 +1430,20 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
                )
                logger.warning(warn_msg)

-            model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

-            config_file = hf_hub_download(
-                pretrained_model_name,
-                cls.config_name,
-                cache_dir=cache_dir,
-                revision=revision,
-                proxies=proxies,
-                force_download=force_download,
-                token=token,
-            )
-
-            config_dict = cls._dict_from_json_file(config_file)
-            ignore_filenames = config_dict.pop("_ignore_files", [])
-
-            # remove ignored filenames
-            model_filenames = set(model_filenames) - set(ignore_filenames)
-            variant_filenames = set(variant_filenames) - set(ignore_filenames)

+            filenames = set(filenames) - set(ignore_filenames)
            if revision in DEPRECATED_REVISION_ARGS and version.parse(
                version.parse(__version__).base_version
            ) >= version.parse("0.22.0"):
-                warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
+                warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, filenames)

            custom_components, folder_names = _get_custom_components_and_folders(
-                pretrained_model_name, config_dict, filenames, variant_filenames, variant
+                pretrained_model_name, config_dict, filenames, variant
            )
-            model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}

            custom_class_name = None
            if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
                custom_pipeline = config_dict["_class_name"][0]
                custom_class_name = config_dict["_class_name"][1]

-            # all filenames compatible with variant will be added
-            allow_patterns = list(model_filenames)
-
-            # allow all patterns from non-model folders
-            # this enables downloading schedulers, tokenizers, ...
-            allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
-            # add custom component files
-            allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
-            # add custom pipeline file
-            allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
-            # also allow downloading config.json files with the model
-            allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
-            # also allow downloading generation_config.json of the transformers model
-            allow_patterns += [os.path.join(k, "generation_config.json") for k in model_folder_names]
-            allow_patterns += [
-                SCHEDULER_CONFIG_NAME,
-                CONFIG_NAME,
-                cls.config_name,
-                CUSTOM_PIPELINE_FILE_NAME,
-            ]

            load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
            load_components_from_hub = len(custom_components) > 0

@@ -1506,12 +1476,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
            expected_components, _ = cls._get_signature_keys(pipeline_class)
            passed_components = [k for k in expected_components if k in kwargs]

+            # retrieve the names of the folders containing model weights
+            model_folder_names = {
+                os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names
+            }
            # retrieve all patterns that should not be downloaded and error out when needed
            ignore_patterns = _get_ignore_patterns(
                passed_components,
                model_folder_names,
-                model_filenames,
-                variant_filenames,
+                filenames,
                use_safetensors,
                from_flax,
                allow_pickle,
@@ -1520,6 +1493,29 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
                variant,
            )

+            model_filenames, variant_filenames = variant_compatible_siblings(
+                filenames, variant=variant, ignore_patterns=ignore_patterns
+            )
+
+            # all filenames compatible with variant will be added
+            allow_patterns = list(model_filenames)
+
+            # allow all patterns from non-model folders
+            # this enables downloading schedulers, tokenizers, ...
+            allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
+            # add custom component files
+            allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
+            # add custom pipeline file
+            allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
+            # also allow downloading config.json files with the model
+            allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
+            allow_patterns += [
+                SCHEDULER_CONFIG_NAME,
+                CONFIG_NAME,
+                cls.config_name,
+                CUSTOM_PIPELINE_FILE_NAME,
+            ]
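For a variant="fp16" download of a standard text-to-image repo, the assembled allow_patterns would come out roughly as follows (hypothetical repo contents; pipeline.py stands in for CUSTOM_PIPELINE_FILE_NAME):

    # ["unet/diffusion_pytorch_model.fp16.safetensors",
    #  "vae/diffusion_pytorch_model.fp16.safetensors",
    #  "text_encoder/model.fp16.safetensors",
    #  "tokenizer/*", "scheduler/*",
    #  "unet/config.json", "vae/config.json", "text_encoder/config.json",
    #  "scheduler_config.json", "config.json", "model_index.json", "pipeline.py"]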

            # Don't download any objects that are passed
            allow_patterns = [
                p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)