Commit 9a1810f

Fix for fetching variants only (#10646)
* update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update
1 parent 1fddee2 commit 9a1810f

File tree: 3 files changed (+378, -130 lines)

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 78 additions & 67 deletions
```diff
@@ -104,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=None):
     extension is replaced with ".safetensors"
     """
     passed_components = passed_components or []
-    if folder_names is not None:
+    if folder_names:
         filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
 
     # extract all components of the pipeline and their associated files
```
```diff
@@ -141,7 +141,25 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=None):
     return True
 
 
-def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
+def filter_model_files(filenames):
+    """Filter model repo files for just files/folders that contain model weights"""
+    weight_names = [
+        WEIGHTS_NAME,
+        SAFETENSORS_WEIGHTS_NAME,
+        FLAX_WEIGHTS_NAME,
+        ONNX_WEIGHTS_NAME,
+        ONNX_EXTERNAL_WEIGHTS_NAME,
+    ]
+
+    if is_transformers_available():
+        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+
+    allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
+
+    return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
+
+
+def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
     weight_names = [
         WEIGHTS_NAME,
         SAFETENSORS_WEIGHTS_NAME,
```
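The new `filter_model_files` helper keeps only the repo entries whose extension matches one of the known weight formats (the various `*_NAME` constants). A minimal, self-contained sketch of the same idea, using hypothetical weight-name constants in place of the real diffusers/transformers ones:

```python
# Stand-ins for WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, etc.
WEIGHT_NAMES = ["diffusion_pytorch_model.bin", "diffusion_pytorch_model.safetensors", "model.onnx"]


def filter_model_files_sketch(filenames):
    """Keep only files whose extension matches one of the known weight formats."""
    allowed_extensions = [name.split(".")[-1] for name in WEIGHT_NAMES]
    return [f for f in filenames if any(f.endswith(ext) for ext in allowed_extensions)]


files = [
    "unet/diffusion_pytorch_model.fp16.safetensors",
    "unet/config.json",
    "scheduler/scheduler_config.json",
    "text_encoder/model.safetensors",
]
print(filter_model_files_sketch(files))
# ['unet/diffusion_pytorch_model.fp16.safetensors', 'text_encoder/model.safetensors']
```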
```diff
@@ -169,6 +187,10 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
     variant_index_re = re.compile(
         rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
     )
+    legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
+    legacy_variant_index_re = re.compile(
+        rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.{variant}\.index\.json$"
+    )
 
     # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
     non_variant_file_re = re.compile(
```
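The two new `legacy_*` patterns exist because sharded variant checkpoints appear on the Hub under two naming conventions: the current one puts the variant before the shard counter (`model.fp16-00001-of-00002.safetensors`, index `*.index.fp16.json`), while the legacy one puts it after (`model-00001-of-00002.fp16.safetensors`, index `*.fp16.index.json`). A rough sketch of the distinction; the second regex here is a simplified illustration, not the exact pattern used above:

```python
import re

variant = "fp16"
transformers_index_format = r"\d{5}-of-\d{5}"

# Legacy sharded-variant weights: model-00001-of-00002.fp16.safetensors
legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
# Current sharded-variant weights: model.fp16-00001-of-00002.safetensors (simplified)
current_variant_file_re = re.compile(rf".*\.{variant}-{transformers_index_format}\.[a-z]+$")

print(bool(legacy_variant_file_re.match("model-00001-of-00002.fp16.safetensors")))   # True
print(bool(current_variant_file_re.match("model.fp16-00001-of-00002.safetensors")))  # True
print(bool(legacy_variant_file_re.match("model.fp16-00001-of-00002.safetensors")))   # False
```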
```diff
@@ -177,54 +199,68 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
     # `text_encoder/pytorch_model.bin.index.json`
     non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
 
-    if variant is not None:
-        variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
-        variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
-        variant_filenames = variant_weights | variant_indexes
-    else:
-        variant_filenames = set()
+    def filter_for_compatible_extensions(filenames, ignore_patterns=None):
+        if not ignore_patterns:
+            return filenames
+
+        # ignore patterns uses glob style patterns e.g *.safetensors but we're only
+        # interested in the extension name
+        return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
+
+    def filter_with_regex(filenames, pattern_re):
+        return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
+
+    # Group files by component
+    components = {}
+    for filename in filenames:
+        if not len(filename.split("/")) == 2:
+            components.setdefault("", []).append(filename)
+            continue
 
-    non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
-    non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
-    non_variant_filenames = non_variant_weights | non_variant_indexes
+        component, _ = filename.split("/")
+        components.setdefault(component, []).append(filename)
 
-    # all variant filenames will be used by default
-    usable_filenames = set(variant_filenames)
+    usable_filenames = set()
+    variant_filenames = set()
+    for component, component_filenames in components.items():
+        component_filenames = filter_for_compatible_extensions(component_filenames, ignore_patterns=ignore_patterns)
+
+        component_variants = set()
+        component_legacy_variants = set()
+        component_non_variants = set()
+        if variant is not None:
+            component_variants = filter_with_regex(component_filenames, variant_file_re)
+            component_variant_index_files = filter_with_regex(component_filenames, variant_index_re)
+
+            component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re)
+            component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re)
+
+        if component_variants or component_legacy_variants:
+            variant_filenames.update(
+                component_variants | component_variant_index_files
+                if component_variants
+                else component_legacy_variants | component_legacy_variant_index_files
+            )
 
-    def convert_to_variant(filename):
-        if "index" in filename:
-            variant_filename = filename.replace("index", f"index.{variant}")
-        elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
-            variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
         else:
-            variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
-        return variant_filename
+            component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)
+            component_variant_index_files = filter_with_regex(component_filenames, non_variant_index_re)
 
-    def find_component(filename):
-        if not len(filename.split("/")) == 2:
-            return
-        component = filename.split("/")[0]
-        return component
-
-    def has_sharded_variant(component, variant, variant_filenames):
-        # If component exists check for sharded variant index filename
-        # If component doesn't exist check main dir for sharded variant index filename
-        component = component + "/" if component else ""
-        variant_index_re = re.compile(
-            rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
-        )
-        return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
+            usable_filenames.update(component_non_variants | component_variant_index_files)
 
-    for filename in non_variant_filenames:
-        if convert_to_variant(filename) in variant_filenames:
-            continue
+    usable_filenames.update(variant_filenames)
 
-        component = find_component(filename)
-        # If a sharded variant exists skip adding to allowed patterns
-        if has_sharded_variant(component, variant, variant_filenames):
-            continue
+    if len(variant_filenames) == 0 and variant is not None:
+        error_message = f"You are trying to load model files of the `variant={variant}`, but no such modeling files are available. "
+        raise ValueError(error_message)
 
-        usable_filenames.add(filename)
+    if len(variant_filenames) > 0 and usable_filenames != variant_filenames:
+        logger.warning(
+            f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
+            f"[{', '.join(variant_filenames)}]\nLoaded non-{variant} filenames:\n"
+            f"[{', '.join(usable_filenames - variant_filenames)}\nIf this behavior is not "
+            f"expected, please check your folder structure."
+        )
 
     return usable_filenames, variant_filenames
```
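The rewritten function now resolves files per component: if a component has variant weights (in either naming convention), only those are kept for it; components without a variant fall back to their non-variant weights; and an error is raised only when a variant was requested but no component provides it. A hedged usage sketch against a hypothetical repo listing (the import path is the private module touched by this commit, so it may change between releases):

```python
from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings

# Hypothetical repo listing: the unet and text_encoder ship an fp16 variant, the vae does not.
filenames = {
    "model_index.json",
    "unet/config.json",
    "unet/diffusion_pytorch_model.safetensors",
    "unet/diffusion_pytorch_model.fp16.safetensors",
    "text_encoder/model.safetensors",
    "text_encoder/model.fp16.safetensors",
    "vae/config.json",
    "vae/diffusion_pytorch_model.safetensors",
}

usable, variant_only = variant_compatible_siblings(
    filenames, variant="fp16", ignore_patterns=["*.bin", "*.onnx", "*.pb", "*.msgpack"]
)
# `variant_only` holds just the fp16 weight files; `usable` additionally contains the
# non-variant vae weights, so a mixed-variant warning is logged for the vae component.
```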

```diff
@@ -922,18 +958,13 @@ def _get_custom_components_and_folders(
                 f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
             )
 
-    if len(variant_filenames) == 0 and variant is not None:
-        error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
-        raise ValueError(error_message)
-
     return custom_components, folder_names
 
 
 def _get_ignore_patterns(
     passed_components,
     model_folder_names: List[str],
     model_filenames: List[str],
-    variant_filenames: List[str],
     use_safetensors: bool,
     from_flax: bool,
     allow_pickle: bool,
```
```diff
@@ -964,33 +995,13 @@ def _get_ignore_patterns(
         if not use_onnx:
             ignore_patterns += ["*.onnx", "*.pb"]
 
-        safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
-        safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
-        if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
-            logger.warning(
-                f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
-                f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
-                f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
-                f"expected, please check your folder structure."
-            )
-
     else:
         ignore_patterns = ["*.safetensors", "*.msgpack"]
 
         use_onnx = use_onnx if use_onnx is not None else is_onnx
         if not use_onnx:
             ignore_patterns += ["*.onnx", "*.pb"]
 
-        bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
-        bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
-        if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
-            logger.warning(
-                f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
-                f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
-                f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
-                f"your folder structure."
-            )
-
     return ignore_patterns
```
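With the per-variant bookkeeping removed, `_get_ignore_patterns` only decides which serialization formats to exclude (glob patterns such as `*.bin` or `*.safetensors`), and `variant_compatible_siblings` reuses those globs purely for their extensions via `filter_for_compatible_extensions`. A small sketch of that extension check, assuming the same `pat.lstrip("*.")` trick as the code above:

```python
ignore_patterns = ["*.bin", "*.onnx", "*.pb"]

# lstrip("*.") removes leading '*' and '.' characters, leaving just the extension.
extensions = [pat.lstrip("*.") for pat in ignore_patterns]
print(extensions)  # ['bin', 'onnx', 'pb']

filenames = {"unet/diffusion_pytorch_model.bin", "unet/diffusion_pytorch_model.safetensors"}
compatible = {f for f in filenames if not any(f.endswith(ext) for ext in extensions)}
print(compatible)  # {'unet/diffusion_pytorch_model.safetensors'}
```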

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 46 additions & 50 deletions
```diff
@@ -89,6 +89,7 @@
     _resolve_custom_pipeline_and_cls,
     _unwrap_model,
     _update_init_kwargs_with_connected_pipeline,
+    filter_model_files,
     load_sub_model,
     maybe_raise_or_warn,
     variant_compatible_siblings,
```
```diff
@@ -1387,10 +1388,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
             revision=revision,
         )
 
-        allow_pickle = False
-        if use_safetensors is None:
-            use_safetensors = True
-            allow_pickle = True
+        allow_pickle = True if (use_safetensors is None or use_safetensors is False) else False
+        use_safetensors = use_safetensors if use_safetensors is not None else True
 
         allow_patterns = None
         ignore_patterns = None
```
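The two rewritten lines change behavior slightly: previously `allow_pickle` was only enabled when `use_safetensors` was left unset, whereas now an explicit `use_safetensors=False` also enables it. A tiny sketch of the resulting combinations (mirroring the updated lines, not the full method):

```python
def resolve_serialization_flags(use_safetensors):
    # Sketch of the updated logic in DiffusionPipeline.download.
    allow_pickle = True if (use_safetensors is None or use_safetensors is False) else False
    use_safetensors = use_safetensors if use_safetensors is not None else True
    return use_safetensors, allow_pickle


for flag in (None, True, False):
    print(flag, "->", resolve_serialization_flags(flag))
# None  -> (True, True)    prefer safetensors, pickle allowed as fallback
# True  -> (True, False)   safetensors only
# False -> (False, True)   pickle weights allowed
```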
```diff
@@ -1405,6 +1404,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
             model_info_call_error = e  # save error to reraise it if model is not cached locally
 
         if not local_files_only:
+            config_file = hf_hub_download(
+                pretrained_model_name,
+                cls.config_name,
+                cache_dir=cache_dir,
+                revision=revision,
+                proxies=proxies,
+                force_download=force_download,
+                token=token,
+            )
+            config_dict = cls._dict_from_json_file(config_file)
+            ignore_filenames = config_dict.pop("_ignore_files", [])
+
             filenames = {sibling.rfilename for sibling in info.siblings}
             if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
                 warn_msg = (
```
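Downloading `model_index.json` before any pattern computation lets its `_ignore_files` entry prune the sibling list up front (the actual subtraction happens a few lines further down). A small sketch of what that pruning does, with made-up config contents:

```python
# Hypothetical model_index.json contents.
config_dict = {
    "_class_name": "StableDiffusionPipeline",
    "_ignore_files": ["vae/diffusion_pytorch_model.bin"],
    "unet": ["diffusers", "UNet2DConditionModel"],
}

ignore_filenames = config_dict.pop("_ignore_files", [])

filenames = {
    "unet/diffusion_pytorch_model.safetensors",
    "vae/diffusion_pytorch_model.bin",
}
filenames = set(filenames) - set(ignore_filenames)
print(filenames)  # {'unet/diffusion_pytorch_model.safetensors'}
```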
```diff
@@ -1419,61 +1430,20 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
                 )
                 logger.warning(warn_msg)
 
-            model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
-
-            config_file = hf_hub_download(
-                pretrained_model_name,
-                cls.config_name,
-                cache_dir=cache_dir,
-                revision=revision,
-                proxies=proxies,
-                force_download=force_download,
-                token=token,
-            )
-
-            config_dict = cls._dict_from_json_file(config_file)
-            ignore_filenames = config_dict.pop("_ignore_files", [])
-
-            # remove ignored filenames
-            model_filenames = set(model_filenames) - set(ignore_filenames)
-            variant_filenames = set(variant_filenames) - set(ignore_filenames)
-
+            filenames = set(filenames) - set(ignore_filenames)
             if revision in DEPRECATED_REVISION_ARGS and version.parse(
                 version.parse(__version__).base_version
             ) >= version.parse("0.22.0"):
-                warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
+                warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, filenames)
 
             custom_components, folder_names = _get_custom_components_and_folders(
-                pretrained_model_name, config_dict, filenames, variant_filenames, variant
+                pretrained_model_name, config_dict, filenames, variant
             )
-            model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
-
             custom_class_name = None
             if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
                 custom_pipeline = config_dict["_class_name"][0]
                 custom_class_name = config_dict["_class_name"][1]
 
-            # all filenames compatible with variant will be added
-            allow_patterns = list(model_filenames)
-
-            # allow all patterns from non-model folders
-            # this enables downloading schedulers, tokenizers, ...
-            allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
-            # add custom component files
-            allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
-            # add custom pipeline file
-            allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
-            # also allow downloading config.json files with the model
-            allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
-            # also allow downloading generation_config.json of the transformers model
-            allow_patterns += [os.path.join(k, "generation_config.json") for k in model_folder_names]
-            allow_patterns += [
-                SCHEDULER_CONFIG_NAME,
-                CONFIG_NAME,
-                cls.config_name,
-                CUSTOM_PIPELINE_FILE_NAME,
-            ]
-
             load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
             load_components_from_hub = len(custom_components) > 0
```
```diff
@@ -1506,12 +1476,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
             expected_components, _ = cls._get_signature_keys(pipeline_class)
             passed_components = [k for k in expected_components if k in kwargs]
 
+            # retrieve the names of the folders containing model weights
+            model_folder_names = {
+                os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names
+            }
             # retrieve all patterns that should not be downloaded and error out when needed
             ignore_patterns = _get_ignore_patterns(
                 passed_components,
                 model_folder_names,
-                model_filenames,
-                variant_filenames,
+                filenames,
                 use_safetensors,
                 from_flax,
                 allow_pickle,
```
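Because `model_filenames` is no longer available at this point, the model folder names are now derived from the raw sibling list, with `filter_model_files` ensuring that only folders actually holding weights count as model folders (everything else later gets a blanket `{folder}/*` allow pattern). A small self-contained sketch of that derivation, with a stand-in for `filter_model_files`:

```python
import os

filenames = {
    "model_index.json",
    "unet/diffusion_pytorch_model.safetensors",
    "scheduler/scheduler_config.json",
    "tokenizer/vocab.json",
}
folder_names = {"unet", "scheduler", "tokenizer"}

# Stand-in for filter_model_files: keep only weight files.
weight_files = {f for f in filenames if f.endswith((".safetensors", ".bin"))}

model_folder_names = {os.path.split(f)[0] for f in weight_files if os.path.split(f)[0] in folder_names}
print(model_folder_names)  # {'unet'}
```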
```diff
@@ -1520,6 +1493,29 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
                 variant,
             )
 
+            model_filenames, variant_filenames = variant_compatible_siblings(
+                filenames, variant=variant, ignore_patterns=ignore_patterns
+            )
+
+            # all filenames compatible with variant will be added
+            allow_patterns = list(model_filenames)
+
+            # allow all patterns from non-model folders
+            # this enables downloading schedulers, tokenizers, ...
+            allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
+            # add custom component files
+            allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
+            # add custom pipeline file
+            allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
+            # also allow downloading config.json files with the model
+            allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
+            allow_patterns += [
+                SCHEDULER_CONFIG_NAME,
+                CONFIG_NAME,
+                cls.config_name,
+                CUSTOM_PIPELINE_FILE_NAME,
+            ]
+
             # Don't download any objects that are passed
             allow_patterns = [
                 p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
```
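Taken together, the reordering means the ignore patterns are known before `variant_compatible_siblings` builds the allow list, so a variant request no longer pulls in both the variant and non-variant copies of each weight. The user-facing effect, sketched with an illustrative repo id:

```python
from diffusers import DiffusionPipeline

# After this change, only the fp16 weight files (plus configs and any components
# that have no fp16 variant) are fetched, instead of both precision variants.
local_path = DiffusionPipeline.download(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",  # illustrative repo id
    variant="fp16",
)
print(local_path)
```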
