Skip to content

Fix mixed variant downloading #11611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,21 +146,27 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
components[component].append(component_filename)

# If there are no component folders check the main directory for safetensors files
filtered_filenames = set()
if not components:
if variant is not None:
filtered_filenames = filter_with_regex(filenames, variant_file_re)
else:

# If no variant filenames exist check if non-variant files are available
if not filtered_filenames:
filtered_filenames = filter_with_regex(filenames, non_variant_file_re)
return any(".safetensors" in filename for filename in filtered_filenames)

# iterate over all files of a component
# check if safetensor files exist for that component
# if variant is provided check if the variant of the safetensors exists
for component, component_filenames in components.items():
matches = []
filtered_component_filenames = set()
# if variant is provided check if the variant of the safetensors exists
if variant is not None:
filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re)
else:

# if variant safetensor files do not exist check for non-variants
if not filtered_component_filenames:
filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
for component_filename in filtered_component_filenames:
filename, extension = os.path.splitext(component_filename)
Expand Down
14 changes: 14 additions & 0 deletions tests/pipelines/test_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,20 @@ def test_diffusers_is_compatible_no_components_only_variants(self):
]
self.assertFalse(is_safetensors_compatible(filenames))

def test_is_compatible_mixed_variants(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.safetensors",
"vae/diffusion_pytorch_model.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

def test_is_compatible_variant_and_non_safetensors(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.safetensors",
"vae/diffusion_pytorch_model.bin",
]
self.assertFalse(is_safetensors_compatible(filenames, variant="fp16"))


class VariantCompatibleSiblingsTest(unittest.TestCase):
def test_only_non_variants_downloaded(self):
Expand Down
50 changes: 19 additions & 31 deletions tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,38 +538,26 @@ def test_download_variant_partly(self):
variant = "no_ema"

with tempfile.TemporaryDirectory() as tmpdirname:
if use_safetensors:
with self.assertRaises(OSError) as error_context:
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants",
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
)
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
else:
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants",
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
)
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist]
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants",
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
)
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist]

unet_files = os.listdir(os.path.join(tmpdirname, "unet"))

# Some of the downloaded files should be a non-variant file, check:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
# only unet has "no_ema" variant
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
# vae, safety_checker and text_encoder should have no variant
assert (
sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
)
assert not any(f.endswith(other_format) for f in files)
unet_files = os.listdir(os.path.join(tmpdirname, "unet"))

# Some of the downloaded files should be a non-variant file, check:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
# only unet has "no_ema" variant
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
# vae, safety_checker and text_encoder should have no variant
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
assert not any(f.endswith(other_format) for f in files)

def test_download_variants_with_sharded_checkpoints(self):
# Here we test for downloading of "variant" files belonging to the `unet` and
Expand Down