@@ -104,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
104
104
extension is replaced with ".safetensors"
105
105
"""
106
106
passed_components = passed_components or []
107
- if folder_names is not None :
107
+ if folder_names :
108
108
filenames = {f for f in filenames if os .path .split (f )[0 ] in folder_names }
109
109
110
110
# extract all components of the pipeline and their associated files
@@ -141,7 +141,25 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
141
141
return True
142
142
143
143
144
- def variant_compatible_siblings (filenames , variant = None ) -> Union [List [os .PathLike ], str ]:
144
+ def filter_model_files (filenames ):
145
+ """Filter model repo files for just files/folders that contain model weights"""
146
+ weight_names = [
147
+ WEIGHTS_NAME ,
148
+ SAFETENSORS_WEIGHTS_NAME ,
149
+ FLAX_WEIGHTS_NAME ,
150
+ ONNX_WEIGHTS_NAME ,
151
+ ONNX_EXTERNAL_WEIGHTS_NAME ,
152
+ ]
153
+
154
+ if is_transformers_available ():
155
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME , TRANSFORMERS_FLAX_WEIGHTS_NAME ]
156
+
157
+ allowed_extensions = [wn .split ("." )[- 1 ] for wn in weight_names ]
158
+
159
+ return [f for f in filenames if any (f .endswith (extension ) for extension in allowed_extensions )]
160
+
161
+
162
+ def variant_compatible_siblings (filenames , variant = None , ignore_patterns = None ) -> Union [List [os .PathLike ], str ]:
145
163
weight_names = [
146
164
WEIGHTS_NAME ,
147
165
SAFETENSORS_WEIGHTS_NAME ,
@@ -169,6 +187,10 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
169
187
variant_index_re = re .compile (
170
188
rf"({ '|' .join (weight_prefixes )} )\.({ '|' .join (weight_suffixs )} )\.index\.{ variant } \.json$"
171
189
)
190
+ legacy_variant_file_re = re .compile (rf".*-{ transformers_index_format } \.{ variant } \.[a-z]+$" )
191
+ legacy_variant_index_re = re .compile (
192
+ rf"({ '|' .join (weight_prefixes )} )\.({ '|' .join (weight_suffixs )} )\.{ variant } \.index\.json$"
193
+ )
172
194
173
195
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
174
196
non_variant_file_re = re .compile (
@@ -177,54 +199,68 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
177
199
# `text_encoder/pytorch_model.bin.index.json`
178
200
non_variant_index_re = re .compile (rf"({ '|' .join (weight_prefixes )} )\.({ '|' .join (weight_suffixs )} )\.index\.json" )
179
201
180
- if variant is not None :
181
- variant_weights = {f for f in filenames if variant_file_re .match (f .split ("/" )[- 1 ]) is not None }
182
- variant_indexes = {f for f in filenames if variant_index_re .match (f .split ("/" )[- 1 ]) is not None }
183
- variant_filenames = variant_weights | variant_indexes
184
- else :
185
- variant_filenames = set ()
202
+ def filter_for_compatible_extensions (filenames , ignore_patterns = None ):
203
+ if not ignore_patterns :
204
+ return filenames
205
+
206
+ # ignore patterns uses glob style patterns e.g *.safetensors but we're only
207
+ # interested in the extension name
208
+ return {f for f in filenames if not any (f .endswith (pat .lstrip ("*." )) for pat in ignore_patterns )}
209
+
210
+ def filter_with_regex (filenames , pattern_re ):
211
+ return {f for f in filenames if pattern_re .match (f .split ("/" )[- 1 ]) is not None }
212
+
213
+ # Group files by component
214
+ components = {}
215
+ for filename in filenames :
216
+ if not len (filename .split ("/" )) == 2 :
217
+ components .setdefault ("" , []).append (filename )
218
+ continue
186
219
187
- non_variant_weights = {f for f in filenames if non_variant_file_re .match (f .split ("/" )[- 1 ]) is not None }
188
- non_variant_indexes = {f for f in filenames if non_variant_index_re .match (f .split ("/" )[- 1 ]) is not None }
189
- non_variant_filenames = non_variant_weights | non_variant_indexes
220
+ component , _ = filename .split ("/" )
221
+ components .setdefault (component , []).append (filename )
190
222
191
- # all variant filenames will be used by default
192
- usable_filenames = set (variant_filenames )
223
+ usable_filenames = set ()
224
+ variant_filenames = set ()
225
+ for component , component_filenames in components .items ():
226
+ component_filenames = filter_for_compatible_extensions (component_filenames , ignore_patterns = ignore_patterns )
227
+
228
+ component_variants = set ()
229
+ component_legacy_variants = set ()
230
+ component_non_variants = set ()
231
+ if variant is not None :
232
+ component_variants = filter_with_regex (component_filenames , variant_file_re )
233
+ component_variant_index_files = filter_with_regex (component_filenames , variant_index_re )
234
+
235
+ component_legacy_variants = filter_with_regex (component_filenames , legacy_variant_file_re )
236
+ component_legacy_variant_index_files = filter_with_regex (component_filenames , legacy_variant_index_re )
237
+
238
+ if component_variants or component_legacy_variants :
239
+ variant_filenames .update (
240
+ component_variants | component_variant_index_files
241
+ if component_variants
242
+ else component_legacy_variants | component_legacy_variant_index_files
243
+ )
193
244
194
- def convert_to_variant (filename ):
195
- if "index" in filename :
196
- variant_filename = filename .replace ("index" , f"index.{ variant } " )
197
- elif re .compile (f"^(.*?){ transformers_index_format } " ).match (filename ) is not None :
198
- variant_filename = f"{ filename .split ('-' )[0 ]} .{ variant } -{ '-' .join (filename .split ('-' )[1 :])} "
199
245
else :
200
- variant_filename = f" { filename . split ( '.' )[ 0 ] } . { variant } . { filename . split ( '.' )[ 1 ] } "
201
- return variant_filename
246
+ component_non_variants = filter_with_regex ( component_filenames , non_variant_file_re )
247
+ component_variant_index_files = filter_with_regex ( component_filenames , non_variant_index_re )
202
248
203
- def find_component (filename ):
204
- if not len (filename .split ("/" )) == 2 :
205
- return
206
- component = filename .split ("/" )[0 ]
207
- return component
208
-
209
- def has_sharded_variant (component , variant , variant_filenames ):
210
- # If component exists check for sharded variant index filename
211
- # If component doesn't exist check main dir for sharded variant index filename
212
- component = component + "/" if component else ""
213
- variant_index_re = re .compile (
214
- rf"{ component } ({ '|' .join (weight_prefixes )} )\.({ '|' .join (weight_suffixs )} )\.index\.{ variant } \.json$"
215
- )
216
- return any (f for f in variant_filenames if variant_index_re .match (f ) is not None )
249
+ usable_filenames .update (component_non_variants | component_variant_index_files )
217
250
218
- for filename in non_variant_filenames :
219
- if convert_to_variant (filename ) in variant_filenames :
220
- continue
251
+ usable_filenames .update (variant_filenames )
221
252
222
- component = find_component (filename )
223
- # If a sharded variant exists skip adding to allowed patterns
224
- if has_sharded_variant (component , variant , variant_filenames ):
225
- continue
253
+ if len (variant_filenames ) == 0 and variant is not None :
254
+ error_message = f"You are trying to load model files of the `variant={ variant } `, but no such modeling files are available. "
255
+ raise ValueError (error_message )
226
256
227
- usable_filenames .add (filename )
257
+ if len (variant_filenames ) > 0 and usable_filenames != variant_filenames :
258
+ logger .warning (
259
+ f"\n A mixture of { variant } and non-{ variant } filenames will be loaded.\n Loaded { variant } filenames:\n "
260
+ f"[{ ', ' .join (variant_filenames )} ]\n Loaded non-{ variant } filenames:\n "
261
+ f"[{ ', ' .join (usable_filenames - variant_filenames )} \n If this behavior is not "
262
+ f"expected, please check your folder structure."
263
+ )
228
264
229
265
return usable_filenames , variant_filenames
230
266
@@ -922,18 +958,13 @@ def _get_custom_components_and_folders(
922
958
f"{ candidate_file } as defined in `model_index.json` does not exist in { pretrained_model_name } and is not a module in 'diffusers/pipelines'."
923
959
)
924
960
925
- if len (variant_filenames ) == 0 and variant is not None :
926
- error_message = f"You are trying to load the model files of the `variant={ variant } `, but no such modeling files are available."
927
- raise ValueError (error_message )
928
-
929
961
return custom_components , folder_names
930
962
931
963
932
964
def _get_ignore_patterns (
933
965
passed_components ,
934
966
model_folder_names : List [str ],
935
967
model_filenames : List [str ],
936
- variant_filenames : List [str ],
937
968
use_safetensors : bool ,
938
969
from_flax : bool ,
939
970
allow_pickle : bool ,
@@ -964,33 +995,13 @@ def _get_ignore_patterns(
964
995
if not use_onnx :
965
996
ignore_patterns += ["*.onnx" , "*.pb" ]
966
997
967
- safetensors_variant_filenames = {f for f in variant_filenames if f .endswith (".safetensors" )}
968
- safetensors_model_filenames = {f for f in model_filenames if f .endswith (".safetensors" )}
969
- if len (safetensors_variant_filenames ) > 0 and safetensors_model_filenames != safetensors_variant_filenames :
970
- logger .warning (
971
- f"\n A mixture of { variant } and non-{ variant } filenames will be loaded.\n Loaded { variant } filenames:\n "
972
- f"[{ ', ' .join (safetensors_variant_filenames )} ]\n Loaded non-{ variant } filenames:\n "
973
- f"[{ ', ' .join (safetensors_model_filenames - safetensors_variant_filenames )} \n If this behavior is not "
974
- f"expected, please check your folder structure."
975
- )
976
-
977
998
else :
978
999
ignore_patterns = ["*.safetensors" , "*.msgpack" ]
979
1000
980
1001
use_onnx = use_onnx if use_onnx is not None else is_onnx
981
1002
if not use_onnx :
982
1003
ignore_patterns += ["*.onnx" , "*.pb" ]
983
1004
984
- bin_variant_filenames = {f for f in variant_filenames if f .endswith (".bin" )}
985
- bin_model_filenames = {f for f in model_filenames if f .endswith (".bin" )}
986
- if len (bin_variant_filenames ) > 0 and bin_model_filenames != bin_variant_filenames :
987
- logger .warning (
988
- f"\n A mixture of { variant } and non-{ variant } filenames will be loaded.\n Loaded { variant } filenames:\n "
989
- f"[{ ', ' .join (bin_variant_filenames )} ]\n Loaded non-{ variant } filenames:\n "
990
- f"[{ ', ' .join (bin_model_filenames - bin_variant_filenames )} \n If this behavior is not expected, please check "
991
- f"your folder structure."
992
- )
993
-
994
1005
return ignore_patterns
995
1006
996
1007
0 commit comments