Skip to content

Commit 72e69ca

Browse files
DN6 and sayakpaul committed
Improve downloads of sharded variants (#9869)
* update * update * update * update --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent cb7016c commit 72e69ca

File tree

2 files changed

+155
-5
lines changed

2 files changed

+155
-5
lines changed

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,10 +198,31 @@ def convert_to_variant(filename):
198198
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
199199
return variant_filename
200200

201-
for f in non_variant_filenames:
202-
variant_filename = convert_to_variant(f)
203-
if variant_filename not in usable_filenames:
204-
usable_filenames.add(f)
201+
def find_component(filename):
    """Return the component folder of a ``"<component>/<file>"`` path.

    Args:
        filename: Repository-relative file path using ``/`` as separator.

    Returns:
        The text before the slash when ``filename`` consists of exactly two
        ``/``-separated parts, otherwise ``None`` (a file in the main
        directory, or one nested more than one level deep).
    """
    parts = filename.split("/")
    if len(parts) != 2:
        # Explicit None keeps every return statement value-returning.
        return None
    return parts[0]
206+
207+
def has_sharded_variant(component, variant, variant_filenames):
    """Tell whether a sharded-variant index file exists for ``component``.

    Args:
        component: Component folder name; a falsy value means the index file
            is looked for in the main directory (no folder prefix).
        variant: Variant identifier expected in the index filename.
        variant_filenames: Iterable of filenames to search.

    Returns:
        ``True`` if any entry of ``variant_filenames`` matches
        ``<component>/<prefix>.<suffix>.index.<variant>.json``, else ``False``.
    """
    # Escape the interpolated values so regex metacharacters in a folder or
    # variant name are matched literally instead of altering the pattern.
    component_prefix = re.escape(component + "/") if component else ""
    # str() preserves the previous f-string behaviour when variant is None.
    variant_token = re.escape(str(variant))
    variant_index_re = re.compile(
        rf"{component_prefix}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant_token}\.json$"
    )
    return any(variant_index_re.match(f) is not None for f in variant_filenames)
215+
216+
for filename in non_variant_filenames:
217+
if convert_to_variant(filename) in variant_filenames:
218+
continue
219+
220+
component = find_component(filename)
221+
# If a sharded variant exists skip adding to allowed patterns
222+
if has_sharded_variant(component, variant, variant_filenames):
223+
continue
224+
225+
usable_filenames.add(filename)
205226

206227
return usable_filenames, variant_filenames
207228

tests/pipelines/test_pipeline_utils.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
StableDiffusionPipeline,
1919
UNet2DConditionModel,
2020
)
21-
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible
21+
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
2222
from diffusers.utils.testing_utils import torch_device
2323

2424

@@ -210,6 +210,135 @@ def test_diffusers_is_compatible_no_components_only_variants(self):
210210
self.assertFalse(is_safetensors_compatible(filenames))
211211

212212

213+
class VariantCompatibleSiblingsTest(unittest.TestCase):
    """Checks which filenames ``variant_compatible_siblings`` selects.

    Every test feeds a list of repository filenames to the function and
    inspects the first element of the returned pair for the presence or
    absence of the variant string.
    """

    def test_only_non_variants_downloaded(self):
        """With ``variant=None`` no selected file may contain the variant."""
        variant = "fp16"
        repo_files = [
            f"vae/diffusion_pytorch_model.{variant}.safetensors",
            "vae/diffusion_pytorch_model.safetensors",
            f"text_encoder/model.{variant}.safetensors",
            "text_encoder/model.safetensors",
            f"unet/diffusion_pytorch_model.{variant}.safetensors",
            "unet/diffusion_pytorch_model.safetensors",
        ]

        selected, _ = variant_compatible_siblings(repo_files, variant=None)
        for name in selected:
            assert variant not in name

    def test_only_variants_downloaded(self):
        """With a variant requested every selected file must contain it."""
        variant = "fp16"
        repo_files = [
            f"vae/diffusion_pytorch_model.{variant}.safetensors",
            "vae/diffusion_pytorch_model.safetensors",
            f"text_encoder/model.{variant}.safetensors",
            "text_encoder/model.safetensors",
            f"unet/diffusion_pytorch_model.{variant}.safetensors",
            "unet/diffusion_pytorch_model.safetensors",
        ]

        selected, _ = variant_compatible_siblings(repo_files, variant=variant)
        for name in selected:
            assert variant in name

    def test_mixed_variants_downloaded(self):
        """Only the text encoder file, which has no variant, may lack it."""
        variant = "fp16"
        non_variant_file = "text_encoder/model.safetensors"
        repo_files = [
            f"vae/diffusion_pytorch_model.{variant}.safetensors",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/model.safetensors",
            f"unet/diffusion_pytorch_model.{variant}.safetensors",
            "unet/diffusion_pytorch_model.safetensors",
        ]

        selected, _ = variant_compatible_siblings(repo_files, variant=variant)
        for name in selected:
            if name == non_variant_file:
                assert variant not in name
            else:
                assert variant in name

    def test_non_variants_in_main_dir_downloaded(self):
        """Main-directory files: ``variant=None`` selects no variant file."""
        variant = "fp16"
        repo_files = [
            f"diffusion_pytorch_model.{variant}.safetensors",
            "diffusion_pytorch_model.safetensors",
            "model.safetensors",
            f"model.{variant}.safetensors",
            f"diffusion_pytorch_model.{variant}.safetensors",
            "diffusion_pytorch_model.safetensors",
        ]

        selected, _ = variant_compatible_siblings(repo_files, variant=None)
        for name in selected:
            assert variant not in name

    def test_variants_in_main_dir_downloaded(self):
        """Main-directory files: a requested variant selects only variants."""
        variant = "fp16"
        repo_files = [
            f"diffusion_pytorch_model.{variant}.safetensors",
            "diffusion_pytorch_model.safetensors",
            "model.safetensors",
            f"model.{variant}.safetensors",
            f"diffusion_pytorch_model.{variant}.safetensors",
            "diffusion_pytorch_model.safetensors",
        ]

        selected, _ = variant_compatible_siblings(repo_files, variant=variant)
        for name in selected:
            assert variant in name

    def test_mixed_variants_in_main_dir_downloaded(self):
        """Main-directory files: only ``model.safetensors`` may lack the variant."""
        variant = "fp16"
        non_variant_file = "model.safetensors"
        repo_files = [
            f"diffusion_pytorch_model.{variant}.safetensors",
            "diffusion_pytorch_model.safetensors",
            "model.safetensors",
            f"diffusion_pytorch_model.{variant}.safetensors",
            "diffusion_pytorch_model.safetensors",
        ]

        selected, _ = variant_compatible_siblings(repo_files, variant=variant)
        for name in selected:
            if name == non_variant_file:
                assert variant not in name
            else:
                assert variant in name

    def test_sharded_non_variants_downloaded(self):
        """Sharded checkpoint: ``variant=None`` selects no variant file."""
        variant = "fp16"
        repo_files = [
            f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
            "unet/diffusion_pytorch_model.safetensors.index.json",
            "unet/diffusion_pytorch_model-00001-of-00003.safetensors",
            "unet/diffusion_pytorch_model-00002-of-00003.safetensors",
            "unet/diffusion_pytorch_model-00003-of-00003.safetensors",
            f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
            f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
        ]

        selected, _ = variant_compatible_siblings(repo_files, variant=None)
        for name in selected:
            assert variant not in name

    def test_sharded_variants_downloaded(self):
        """Sharded checkpoint: a requested variant selects only variant files."""
        variant = "fp16"
        repo_files = [
            f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json",
            "unet/diffusion_pytorch_model.safetensors.index.json",
            "unet/diffusion_pytorch_model-00001-of-00003.safetensors",
            "unet/diffusion_pytorch_model-00002-of-00003.safetensors",
            "unet/diffusion_pytorch_model-00003-of-00003.safetensors",
            f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
            f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
        ]

        selected, _ = variant_compatible_siblings(repo_files, variant=variant)
        for name in selected:
            assert variant in name

    def test_sharded_mixed_variants_downloaded(self):
        """Sharded checkpoint: names containing ``unet`` must lack the variant.

        The unet is listed without any variant files, while the vae has both
        sharded variant and sharded non-variant files.
        """
        variant = "fp16"
        allowed_non_variant = "unet"
        repo_files = [
            f"vae/diffusion_pytorch_model.safetensors.index.{variant}.json",
            "vae/diffusion_pytorch_model.safetensors.index.json",
            "unet/diffusion_pytorch_model.safetensors.index.json",
            "unet/diffusion_pytorch_model-00001-of-00003.safetensors",
            "unet/diffusion_pytorch_model-00002-of-00003.safetensors",
            "unet/diffusion_pytorch_model-00003-of-00003.safetensors",
            f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors",
            f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors",
            "vae/diffusion_pytorch_model-00001-of-00003.safetensors",
            "vae/diffusion_pytorch_model-00002-of-00003.safetensors",
            "vae/diffusion_pytorch_model-00003-of-00003.safetensors",
        ]

        selected, _ = variant_compatible_siblings(repo_files, variant=variant)
        for name in selected:
            if allowed_non_variant in name:
                assert variant not in name
            else:
                assert variant in name
340+
341+
213342
class ProgressBarTests(unittest.TestCase):
214343
def get_dummy_components_image_generation(self):
215344
cross_attention_dim = 8

0 commit comments

Comments
 (0)