Skip to content

Commit 826f435

Browse files
authored
Fix mixed variant downloading (#11611)
* update * update
1 parent 4af76d0 commit 826f435

File tree

3 files changed

+42
-34
lines changed

3 files changed

+42
-34
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,21 +146,27 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
146146
components[component].append(component_filename)
147147

148148
# If there are no component folders check the main directory for safetensors files
149+
filtered_filenames = set()
149150
if not components:
150151
if variant is not None:
151152
filtered_filenames = filter_with_regex(filenames, variant_file_re)
152-
else:
153+
154+
# If no variant filenames exist check if non-variant files are available
155+
if not filtered_filenames:
153156
filtered_filenames = filter_with_regex(filenames, non_variant_file_re)
154157
return any(".safetensors" in filename for filename in filtered_filenames)
155158

156159
# iterate over all files of a component
157160
# check if safetensor files exist for that component
158-
# if variant is provided check if the variant of the safetensors exists
159161
for component, component_filenames in components.items():
160162
matches = []
163+
filtered_component_filenames = set()
164+
# if variant is provided check if the variant of the safetensors exists
161165
if variant is not None:
162166
filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re)
163-
else:
167+
168+
# if variant safetensor files do not exist check for non-variants
169+
if not filtered_component_filenames:
164170
filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
165171
for component_filename in filtered_component_filenames:
166172
filename, extension = os.path.splitext(component_filename)

tests/pipelines/test_pipeline_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,20 @@ def test_diffusers_is_compatible_no_components_only_variants(self):
217217
]
218218
self.assertFalse(is_safetensors_compatible(filenames))
219219

220+
def test_is_compatible_mixed_variants(self):
221+
filenames = [
222+
"unet/diffusion_pytorch_model.fp16.safetensors",
223+
"vae/diffusion_pytorch_model.safetensors",
224+
]
225+
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
226+
227+
def test_is_compatible_variant_and_non_safetensors(self):
228+
filenames = [
229+
"unet/diffusion_pytorch_model.fp16.safetensors",
230+
"vae/diffusion_pytorch_model.bin",
231+
]
232+
self.assertFalse(is_safetensors_compatible(filenames, variant="fp16"))
233+
220234

221235
class VariantCompatibleSiblingsTest(unittest.TestCase):
222236
def test_only_non_variants_downloaded(self):

tests/pipelines/test_pipelines.py

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -538,38 +538,26 @@ def test_download_variant_partly(self):
538538
variant = "no_ema"
539539

540540
with tempfile.TemporaryDirectory() as tmpdirname:
541-
if use_safetensors:
542-
with self.assertRaises(OSError) as error_context:
543-
tmpdirname = StableDiffusionPipeline.download(
544-
"hf-internal-testing/stable-diffusion-all-variants",
545-
cache_dir=tmpdirname,
546-
variant=variant,
547-
use_safetensors=use_safetensors,
548-
)
549-
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
550-
else:
551-
tmpdirname = StableDiffusionPipeline.download(
552-
"hf-internal-testing/stable-diffusion-all-variants",
553-
cache_dir=tmpdirname,
554-
variant=variant,
555-
use_safetensors=use_safetensors,
556-
)
557-
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
558-
files = [item for sublist in all_root_files for item in sublist]
541+
tmpdirname = StableDiffusionPipeline.download(
542+
"hf-internal-testing/stable-diffusion-all-variants",
543+
cache_dir=tmpdirname,
544+
variant=variant,
545+
use_safetensors=use_safetensors,
546+
)
547+
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
548+
files = [item for sublist in all_root_files for item in sublist]
559549

560-
unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
561-
562-
# Some of the downloaded files should be a non-variant file, check:
563-
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
564-
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
565-
# only unet has "no_ema" variant
566-
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
567-
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
568-
# vae, safety_checker and text_encoder should have no variant
569-
assert (
570-
sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
571-
)
572-
assert not any(f.endswith(other_format) for f in files)
550+
unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
551+
552+
# Some of the downloaded files should be a non-variant file, check:
553+
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
554+
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
555+
# only unet has "no_ema" variant
556+
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
557+
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
558+
# vae, safety_checker and text_encoder should have no variant
559+
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
560+
assert not any(f.endswith(other_format) for f in files)
573561

574562
def test_download_variants_with_sharded_checkpoints(self):
575563
# Here we test for downloading of "variant" files belonging to the `unet` and

0 commit comments

Comments
 (0)