Skip to content

Commit b5c2050

Browse files
Fix bug when variant and safetensor file does not match (#11587)
* Apply style fixes
* init test
  Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
* adjust
  Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
* add the variant check when there are no component folders
  Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
* update related test cases
  Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
* update related unit test cases
  Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
* adjust
  Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
* Apply style fixes
---------
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 7ae546f commit b5c2050

File tree

3 files changed

+101
-50
lines changed

3 files changed

+101
-50
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
9393

9494

95-
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
95+
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None, variant=None) -> bool:
9696
"""
9797
Checking for safetensors compatibility:
9898
- The model is safetensors compatible only if there is a safetensors file for each model component present in
@@ -103,6 +103,31 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
103103
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
104104
extension is replaced with ".safetensors"
105105
"""
106+
weight_names = [
107+
WEIGHTS_NAME,
108+
SAFETENSORS_WEIGHTS_NAME,
109+
FLAX_WEIGHTS_NAME,
110+
ONNX_WEIGHTS_NAME,
111+
ONNX_EXTERNAL_WEIGHTS_NAME,
112+
]
113+
114+
if is_transformers_available():
115+
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
116+
117+
# pytorch_model, diffusion_pytorch_model, model, ...
118+
weight_prefixes = [w.split(".")[0] for w in weight_names]
119+
# .bin, .safetensors, ...
120+
weight_suffixs = [w.split(".")[-1] for w in weight_names]
121+
# -00001-of-00002
122+
transformers_index_format = r"\d{5}-of-\d{5}"
123+
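# variant files such as `diffusion_pytorch_model.fp16.bin` or `model.fp16-00001-of-00002.safetensors`, and non-variant files such as `diffusion_pytorch_model.bin` or `model-00001-of-00002.safetensors`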
124+
variant_file_re = re.compile(
125+
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
126+
)
127+
non_variant_file_re = re.compile(
128+
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
129+
)
130+
106131
passed_components = passed_components or []
107132
if folder_names:
108133
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
@@ -122,14 +147,22 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
122147

123148
# If there are no component folders check the main directory for safetensors files
124149
if not components:
125-
return any(".safetensors" in filename for filename in filenames)
150+
if variant is not None:
151+
filtered_filenames = filter_with_regex(filenames, variant_file_re)
152+
else:
153+
filtered_filenames = filter_with_regex(filenames, non_variant_file_re)
154+
return any(".safetensors" in filename for filename in filtered_filenames)
126155

127156
# iterate over all files of a component
128157
# check if safetensor files exist for that component
129158
# if variant is provided check if the variant of the safetensors exists
130159
for component, component_filenames in components.items():
131160
matches = []
132-
for component_filename in component_filenames:
161+
if variant is not None:
162+
filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re)
163+
else:
164+
filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
165+
for component_filename in filtered_component_filenames:
133166
filename, extension = os.path.splitext(component_filename)
134167

135168
match_exists = extension == ".safetensors"
@@ -159,6 +192,10 @@ def filter_model_files(filenames):
159192
return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
160193

161194

195+
def filter_with_regex(filenames, pattern_re):
196+
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
197+
198+
162199
def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
163200
weight_names = [
164201
WEIGHTS_NAME,
@@ -207,9 +244,6 @@ def filter_for_compatible_extensions(filenames, ignore_patterns=None):
207244
# interested in the extension name
208245
return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
209246

210-
def filter_with_regex(filenames, pattern_re):
211-
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
212-
213247
# Group files by component
214248
components = {}
215249
for filename in filenames:
@@ -997,7 +1031,7 @@ def _get_ignore_patterns(
9971031
use_safetensors
9981032
and not allow_pickle
9991033
and not is_safetensors_compatible(
1000-
model_filenames, passed_components=passed_components, folder_names=model_folder_names
1034+
model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
10011035
)
10021036
):
10031037
raise EnvironmentError(
@@ -1008,7 +1042,7 @@ def _get_ignore_patterns(
10081042
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
10091043

10101044
elif use_safetensors and is_safetensors_compatible(
1011-
model_filenames, passed_components=passed_components, folder_names=model_folder_names
1045+
model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
10121046
):
10131047
ignore_patterns = ["*.bin", "*.msgpack"]
10141048

tests/pipelines/test_pipeline_utils.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,21 +87,24 @@ def test_all_is_compatible_variant(self):
8787
"unet/diffusion_pytorch_model.fp16.bin",
8888
"unet/diffusion_pytorch_model.fp16.safetensors",
8989
]
90-
self.assertTrue(is_safetensors_compatible(filenames))
90+
self.assertFalse(is_safetensors_compatible(filenames))
91+
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
9192

9293
def test_diffusers_model_is_compatible_variant(self):
9394
filenames = [
9495
"unet/diffusion_pytorch_model.fp16.bin",
9596
"unet/diffusion_pytorch_model.fp16.safetensors",
9697
]
97-
self.assertTrue(is_safetensors_compatible(filenames))
98+
self.assertFalse(is_safetensors_compatible(filenames))
99+
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
98100

99101
def test_diffusers_model_is_compatible_variant_mixed(self):
100102
filenames = [
101103
"unet/diffusion_pytorch_model.bin",
102104
"unet/diffusion_pytorch_model.fp16.safetensors",
103105
]
104-
self.assertTrue(is_safetensors_compatible(filenames))
106+
self.assertFalse(is_safetensors_compatible(filenames))
107+
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
105108

106109
def test_diffusers_model_is_not_compatible_variant(self):
107110
filenames = [
@@ -121,7 +124,8 @@ def test_transformer_model_is_compatible_variant(self):
121124
"text_encoder/pytorch_model.fp16.bin",
122125
"text_encoder/model.fp16.safetensors",
123126
]
124-
self.assertTrue(is_safetensors_compatible(filenames))
127+
self.assertFalse(is_safetensors_compatible(filenames))
128+
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
125129

126130
def test_transformer_model_is_not_compatible_variant(self):
127131
filenames = [
@@ -145,7 +149,8 @@ def test_transformer_model_is_compatible_variant_extra_folder(self):
145149
"unet/diffusion_pytorch_model.fp16.bin",
146150
"unet/diffusion_pytorch_model.fp16.safetensors",
147151
]
148-
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
152+
self.assertFalse(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
153+
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}, variant="fp16"))
149154

150155
def test_transformer_model_is_not_compatible_variant_extra_folder(self):
151156
filenames = [
@@ -173,7 +178,8 @@ def test_transformers_is_compatible_variant_sharded(self):
173178
"text_encoder/model.fp16-00001-of-00002.safetensors",
174179
"text_encoder/model.fp16-00001-of-00002.safetensors",
175180
]
176-
self.assertTrue(is_safetensors_compatible(filenames))
181+
self.assertFalse(is_safetensors_compatible(filenames))
182+
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
177183

178184
def test_diffusers_is_compatible_sharded(self):
179185
filenames = [
@@ -189,13 +195,15 @@ def test_diffusers_is_compatible_variant_sharded(self):
189195
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
190196
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
191197
]
192-
self.assertTrue(is_safetensors_compatible(filenames))
198+
self.assertFalse(is_safetensors_compatible(filenames))
199+
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
193200

194201
def test_diffusers_is_compatible_only_variants(self):
195202
filenames = [
196203
"unet/diffusion_pytorch_model.fp16.safetensors",
197204
]
198-
self.assertTrue(is_safetensors_compatible(filenames))
205+
self.assertFalse(is_safetensors_compatible(filenames))
206+
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
199207

200208
def test_diffusers_is_compatible_no_components(self):
201209
filenames = [

tests/pipelines/test_pipelines.py

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -538,26 +538,38 @@ def test_download_variant_partly(self):
538538
variant = "no_ema"
539539

540540
with tempfile.TemporaryDirectory() as tmpdirname:
541-
tmpdirname = StableDiffusionPipeline.download(
542-
"hf-internal-testing/stable-diffusion-all-variants",
543-
cache_dir=tmpdirname,
544-
variant=variant,
545-
use_safetensors=use_safetensors,
546-
)
547-
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
548-
files = [item for sublist in all_root_files for item in sublist]
549-
550-
unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
541+
if use_safetensors:
542+
with self.assertRaises(OSError) as error_context:
543+
tmpdirname = StableDiffusionPipeline.download(
544+
"hf-internal-testing/stable-diffusion-all-variants",
545+
cache_dir=tmpdirname,
546+
variant=variant,
547+
use_safetensors=use_safetensors,
548+
)
549+
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
550+
else:
551+
tmpdirname = StableDiffusionPipeline.download(
552+
"hf-internal-testing/stable-diffusion-all-variants",
553+
cache_dir=tmpdirname,
554+
variant=variant,
555+
use_safetensors=use_safetensors,
556+
)
557+
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
558+
files = [item for sublist in all_root_files for item in sublist]
551559

552-
# Some of the downloaded files should be a non-variant file, check:
553-
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
554-
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
555-
# only unet has "no_ema" variant
556-
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
557-
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
558-
# vae, safety_checker and text_encoder should have no variant
559-
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
560-
assert not any(f.endswith(other_format) for f in files)
560+
unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
561+
562+
# Some of the downloaded files should be a non-variant file, check:
563+
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
564+
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
565+
# only unet has "no_ema" variant
566+
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
567+
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
568+
# vae, safety_checker and text_encoder should have no variant
569+
assert (
570+
sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
571+
)
572+
assert not any(f.endswith(other_format) for f in files)
561573

562574
def test_download_variants_with_sharded_checkpoints(self):
563575
# Here we test for downloading of "variant" files belonging to the `unet` and
@@ -588,20 +600,17 @@ def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
588600
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
589601
deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant"
590602

591-
for is_local in [True, False]:
592-
with CaptureLogger(logger) as cap_logger:
593-
with tempfile.TemporaryDirectory() as tmpdirname:
594-
local_repo_id = repo_id
595-
if is_local:
596-
local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
603+
with CaptureLogger(logger) as cap_logger:
604+
with tempfile.TemporaryDirectory() as tmpdirname:
605+
local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
597606

598-
_ = DiffusionPipeline.from_pretrained(
599-
local_repo_id,
600-
safety_checker=None,
601-
variant="fp16",
602-
use_safetensors=True,
603-
)
604-
assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
607+
_ = DiffusionPipeline.from_pretrained(
608+
local_repo_id,
609+
safety_checker=None,
610+
variant="fp16",
611+
use_safetensors=True,
612+
)
613+
assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
605614

606615
def test_download_safetensors_only_variant_exists_for_model(self):
607616
variant = None
@@ -616,7 +625,7 @@ def test_download_safetensors_only_variant_exists_for_model(self):
616625
variant=variant,
617626
use_safetensors=use_safetensors,
618627
)
619-
assert "Error no file name" in str(error_context.exception)
628+
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
620629

621630
# text encoder has fp16 variants so we can load it
622631
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -675,7 +684,7 @@ def test_download_safetensors_variant_does_not_exist_for_model(self):
675684
use_safetensors=use_safetensors,
676685
)
677686

678-
assert "Error no file name" in str(error_context.exception)
687+
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
679688

680689
def test_download_bin_variant_does_not_exist_for_model(self):
681690
variant = "no_ema"

0 commit comments

Comments
 (0)