92
92
ALL_IMPORTABLE_CLASSES .update (LOADABLE_CLASSES [library ])
93
93
94
94
95
- def is_safetensors_compatible (filenames , passed_components = None , folder_names = None ) -> bool :
95
+ def is_safetensors_compatible (filenames , passed_components = None , folder_names = None , variant = None ) -> bool :
96
96
"""
97
97
Checking for safetensors compatibility:
98
98
- The model is safetensors compatible only if there is a safetensors file for each model component present in
@@ -103,6 +103,31 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
103
103
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
104
104
extension is replaced with ".safetensors"
105
105
"""
106
+ weight_names = [
107
+ WEIGHTS_NAME ,
108
+ SAFETENSORS_WEIGHTS_NAME ,
109
+ FLAX_WEIGHTS_NAME ,
110
+ ONNX_WEIGHTS_NAME ,
111
+ ONNX_EXTERNAL_WEIGHTS_NAME ,
112
+ ]
113
+
114
+ if is_transformers_available ():
115
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME , TRANSFORMERS_SAFE_WEIGHTS_NAME , TRANSFORMERS_FLAX_WEIGHTS_NAME ]
116
+
117
+ # model_pytorch, diffusion_model_pytorch, ...
118
+ weight_prefixes = [w .split ("." )[0 ] for w in weight_names ]
119
+ # .bin, .safetensors, ...
120
+ weight_suffixs = [w .split ("." )[- 1 ] for w in weight_names ]
121
+ # -00001-of-00002
122
+ transformers_index_format = r"\d{5}-of-\d{5}"
123
+ # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
124
+ variant_file_re = re .compile (
125
+ rf"({ '|' .join (weight_prefixes )} )\.({ variant } |{ variant } -{ transformers_index_format } )\.({ '|' .join (weight_suffixs )} )$"
126
+ )
127
+ non_variant_file_re = re .compile (
128
+ rf"({ '|' .join (weight_prefixes )} )(-{ transformers_index_format } )?\.({ '|' .join (weight_suffixs )} )$"
129
+ )
130
+
106
131
passed_components = passed_components or []
107
132
if folder_names :
108
133
filenames = {f for f in filenames if os .path .split (f )[0 ] in folder_names }
@@ -122,14 +147,22 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
122
147
123
148
# If there are no component folders check the main directory for safetensors files
124
149
if not components :
125
- return any (".safetensors" in filename for filename in filenames )
150
+ if variant is not None :
151
+ filtered_filenames = filter_with_regex (filenames , variant_file_re )
152
+ else :
153
+ filtered_filenames = filter_with_regex (filenames , non_variant_file_re )
154
+ return any (".safetensors" in filename for filename in filtered_filenames )
126
155
127
156
# iterate over all files of a component
128
157
# check if safetensor files exist for that component
129
158
# if variant is provided check if the variant of the safetensors exists
130
159
for component , component_filenames in components .items ():
131
160
matches = []
132
- for component_filename in component_filenames :
161
+ if variant is not None :
162
+ filtered_component_filenames = filter_with_regex (component_filenames , variant_file_re )
163
+ else :
164
+ filtered_component_filenames = filter_with_regex (component_filenames , non_variant_file_re )
165
+ for component_filename in filtered_component_filenames :
133
166
filename , extension = os .path .splitext (component_filename )
134
167
135
168
match_exists = extension == ".safetensors"
@@ -159,6 +192,10 @@ def filter_model_files(filenames):
159
192
return [f for f in filenames if any (f .endswith (extension ) for extension in allowed_extensions )]
160
193
161
194
195
+ def filter_with_regex (filenames , pattern_re ):
196
+ return {f for f in filenames if pattern_re .match (f .split ("/" )[- 1 ]) is not None }
197
+
198
+
162
199
def variant_compatible_siblings (filenames , variant = None , ignore_patterns = None ) -> Union [List [os .PathLike ], str ]:
163
200
weight_names = [
164
201
WEIGHTS_NAME ,
@@ -207,9 +244,6 @@ def filter_for_compatible_extensions(filenames, ignore_patterns=None):
207
244
# interested in the extension name
208
245
return {f for f in filenames if not any (f .endswith (pat .lstrip ("*." )) for pat in ignore_patterns )}
209
246
210
- def filter_with_regex (filenames , pattern_re ):
211
- return {f for f in filenames if pattern_re .match (f .split ("/" )[- 1 ]) is not None }
212
-
213
247
# Group files by component
214
248
components = {}
215
249
for filename in filenames :
@@ -997,7 +1031,7 @@ def _get_ignore_patterns(
997
1031
use_safetensors
998
1032
and not allow_pickle
999
1033
and not is_safetensors_compatible (
1000
- model_filenames , passed_components = passed_components , folder_names = model_folder_names
1034
+ model_filenames , passed_components = passed_components , folder_names = model_folder_names , variant = variant
1001
1035
)
1002
1036
):
1003
1037
raise EnvironmentError (
@@ -1008,7 +1042,7 @@ def _get_ignore_patterns(
1008
1042
ignore_patterns = ["*.bin" , "*.safetensors" , "*.onnx" , "*.pb" ]
1009
1043
1010
1044
elif use_safetensors and is_safetensors_compatible (
1011
- model_filenames , passed_components = passed_components , folder_names = model_folder_names
1045
+ model_filenames , passed_components = passed_components , folder_names = model_folder_names , variant = variant
1012
1046
):
1013
1047
ignore_patterns = ["*.bin" , "*.msgpack" ]
1014
1048
0 commit comments