@@ -538,38 +538,26 @@ def test_download_variant_partly(self):
538
538
variant = "no_ema"
539
539
540
540
with tempfile .TemporaryDirectory () as tmpdirname :
541
- if use_safetensors :
542
- with self .assertRaises (OSError ) as error_context :
543
- tmpdirname = StableDiffusionPipeline .download (
544
- "hf-internal-testing/stable-diffusion-all-variants" ,
545
- cache_dir = tmpdirname ,
546
- variant = variant ,
547
- use_safetensors = use_safetensors ,
548
- )
549
- assert "Could not find the necessary `safetensors` weights" in str (error_context .exception )
550
- else :
551
- tmpdirname = StableDiffusionPipeline .download (
552
- "hf-internal-testing/stable-diffusion-all-variants" ,
553
- cache_dir = tmpdirname ,
554
- variant = variant ,
555
- use_safetensors = use_safetensors ,
556
- )
557
- all_root_files = [t [- 1 ] for t in os .walk (tmpdirname )]
558
- files = [item for sublist in all_root_files for item in sublist ]
541
+ tmpdirname = StableDiffusionPipeline .download (
542
+ "hf-internal-testing/stable-diffusion-all-variants" ,
543
+ cache_dir = tmpdirname ,
544
+ variant = variant ,
545
+ use_safetensors = use_safetensors ,
546
+ )
547
+ all_root_files = [t [- 1 ] for t in os .walk (tmpdirname )]
548
+ files = [item for sublist in all_root_files for item in sublist ]
559
549
560
- unet_files = os .listdir (os .path .join (tmpdirname , "unet" ))
561
-
562
- # Some of the downloaded files should be a non-variant file, check:
563
- # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
564
- assert len (files ) == 15 , f"We should only download 15 files, not { len (files )} "
565
- # only unet has "no_ema" variant
566
- assert f"diffusion_pytorch_model.{ variant } { this_format } " in unet_files
567
- assert len ([f for f in files if f .endswith (f"{ variant } { this_format } " )]) == 1
568
- # vae, safety_checker and text_encoder should have no variant
569
- assert (
570
- sum (f .endswith (this_format ) and not f .endswith (f"{ variant } { this_format } " ) for f in files ) == 3
571
- )
572
- assert not any (f .endswith (other_format ) for f in files )
550
+ unet_files = os .listdir (os .path .join (tmpdirname , "unet" ))
551
+
552
+ # Some of the downloaded files should be a non-variant file, check:
553
+ # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
554
+ assert len (files ) == 15 , f"We should only download 15 files, not { len (files )} "
555
+ # only unet has "no_ema" variant
556
+ assert f"diffusion_pytorch_model.{ variant } { this_format } " in unet_files
557
+ assert len ([f for f in files if f .endswith (f"{ variant } { this_format } " )]) == 1
558
+ # vae, safety_checker and text_encoder should have no variant
559
+ assert sum (f .endswith (this_format ) and not f .endswith (f"{ variant } { this_format } " ) for f in files ) == 3
560
+ assert not any (f .endswith (other_format ) for f in files )
573
561
574
562
def test_download_variants_with_sharded_checkpoints (self ):
575
563
# Here we test for downloading of "variant" files belonging to the `unet` and
0 commit comments