|
18 | 18 | StableDiffusionPipeline,
|
19 | 19 | UNet2DConditionModel,
|
20 | 20 | )
|
21 |
| -from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible |
| 21 | +from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings |
22 | 22 | from diffusers.utils.testing_utils import torch_device
|
23 | 23 |
|
24 | 24 |
|
@@ -210,6 +210,135 @@ def test_diffusers_is_compatible_no_components_only_variants(self):
|
210 | 210 | self.assertFalse(is_safetensors_compatible(filenames))
|
211 | 211 |
|
212 | 212 |
|
| 213 | +class VariantCompatibleSiblingsTest(unittest.TestCase): |
| 214 | + def test_only_non_variants_downloaded(self): |
| 215 | + variant = "fp16" |
| 216 | + filenames = [ |
| 217 | + f"vae/diffusion_pytorch_model.{variant}.safetensors", |
| 218 | + "vae/diffusion_pytorch_model.safetensors", |
| 219 | + f"text_encoder/model.{variant}.safetensors", |
| 220 | + "text_encoder/model.safetensors", |
| 221 | + f"unet/diffusion_pytorch_model.{variant}.safetensors", |
| 222 | + "unet/diffusion_pytorch_model.safetensors", |
| 223 | + ] |
| 224 | + |
| 225 | + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) |
| 226 | + assert all(variant not in f for f in model_filenames) |
| 227 | + |
| 228 | + def test_only_variants_downloaded(self): |
| 229 | + variant = "fp16" |
| 230 | + filenames = [ |
| 231 | + f"vae/diffusion_pytorch_model.{variant}.safetensors", |
| 232 | + "vae/diffusion_pytorch_model.safetensors", |
| 233 | + f"text_encoder/model.{variant}.safetensors", |
| 234 | + "text_encoder/model.safetensors", |
| 235 | + f"unet/diffusion_pytorch_model.{variant}.safetensors", |
| 236 | + "unet/diffusion_pytorch_model.safetensors", |
| 237 | + ] |
| 238 | + |
| 239 | + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) |
| 240 | + assert all(variant in f for f in model_filenames) |
| 241 | + |
| 242 | + def test_mixed_variants_downloaded(self): |
| 243 | + variant = "fp16" |
| 244 | + non_variant_file = "text_encoder/model.safetensors" |
| 245 | + filenames = [ |
| 246 | + f"vae/diffusion_pytorch_model.{variant}.safetensors", |
| 247 | + "vae/diffusion_pytorch_model.safetensors", |
| 248 | + "text_encoder/model.safetensors", |
| 249 | + f"unet/diffusion_pytorch_model.{variant}.safetensors", |
| 250 | + "unet/diffusion_pytorch_model.safetensors", |
| 251 | + ] |
| 252 | + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) |
| 253 | + assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) |
| 254 | + |
| 255 | + def test_non_variants_in_main_dir_downloaded(self): |
| 256 | + variant = "fp16" |
| 257 | + filenames = [ |
| 258 | + f"diffusion_pytorch_model.{variant}.safetensors", |
| 259 | + "diffusion_pytorch_model.safetensors", |
| 260 | + "model.safetensors", |
| 261 | + f"model.{variant}.safetensors", |
| 262 | + f"diffusion_pytorch_model.{variant}.safetensors", |
| 263 | + "diffusion_pytorch_model.safetensors", |
| 264 | + ] |
| 265 | + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) |
| 266 | + assert all(variant not in f for f in model_filenames) |
| 267 | + |
| 268 | + def test_variants_in_main_dir_downloaded(self): |
| 269 | + variant = "fp16" |
| 270 | + filenames = [ |
| 271 | + f"diffusion_pytorch_model.{variant}.safetensors", |
| 272 | + "diffusion_pytorch_model.safetensors", |
| 273 | + "model.safetensors", |
| 274 | + f"model.{variant}.safetensors", |
| 275 | + f"diffusion_pytorch_model.{variant}.safetensors", |
| 276 | + "diffusion_pytorch_model.safetensors", |
| 277 | + ] |
| 278 | + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) |
| 279 | + assert all(variant in f for f in model_filenames) |
| 280 | + |
| 281 | + def test_mixed_variants_in_main_dir_downloaded(self): |
| 282 | + variant = "fp16" |
| 283 | + non_variant_file = "model.safetensors" |
| 284 | + filenames = [ |
| 285 | + f"diffusion_pytorch_model.{variant}.safetensors", |
| 286 | + "diffusion_pytorch_model.safetensors", |
| 287 | + "model.safetensors", |
| 288 | + f"diffusion_pytorch_model.{variant}.safetensors", |
| 289 | + "diffusion_pytorch_model.safetensors", |
| 290 | + ] |
| 291 | + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) |
| 292 | + assert all(variant in f if f != non_variant_file else variant not in f for f in model_filenames) |
| 293 | + |
| 294 | + def test_sharded_non_variants_downloaded(self): |
| 295 | + variant = "fp16" |
| 296 | + filenames = [ |
| 297 | + f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json", |
| 298 | + "unet/diffusion_pytorch_model.safetensors.index.json", |
| 299 | + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", |
| 300 | + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", |
| 301 | + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", |
| 302 | + f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", |
| 303 | + f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", |
| 304 | + ] |
| 305 | + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=None) |
| 306 | + assert all(variant not in f for f in model_filenames) |
| 307 | + |
| 308 | + def test_sharded_variants_downloaded(self): |
| 309 | + variant = "fp16" |
| 310 | + filenames = [ |
| 311 | + f"unet/diffusion_pytorch_model.safetensors.index.{variant}.json", |
| 312 | + "unet/diffusion_pytorch_model.safetensors.index.json", |
| 313 | + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", |
| 314 | + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", |
| 315 | + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", |
| 316 | + f"unet/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", |
| 317 | + f"unet/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", |
| 318 | + ] |
| 319 | + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) |
| 320 | + assert all(variant in f for f in model_filenames) |
| 321 | + |
| 322 | + def test_sharded_mixed_variants_downloaded(self): |
| 323 | + variant = "fp16" |
| 324 | + allowed_non_variant = "unet" |
| 325 | + filenames = [ |
| 326 | + f"vae/diffusion_pytorch_model.safetensors.index.{variant}.json", |
| 327 | + "vae/diffusion_pytorch_model.safetensors.index.json", |
| 328 | + "unet/diffusion_pytorch_model.safetensors.index.json", |
| 329 | + "unet/diffusion_pytorch_model-00001-of-00003.safetensors", |
| 330 | + "unet/diffusion_pytorch_model-00002-of-00003.safetensors", |
| 331 | + "unet/diffusion_pytorch_model-00003-of-00003.safetensors", |
| 332 | + f"vae/diffusion_pytorch_model.{variant}-00001-of-00002.safetensors", |
| 333 | + f"vae/diffusion_pytorch_model.{variant}-00002-of-00002.safetensors", |
| 334 | + "vae/diffusion_pytorch_model-00001-of-00003.safetensors", |
| 335 | + "vae/diffusion_pytorch_model-00002-of-00003.safetensors", |
| 336 | + "vae/diffusion_pytorch_model-00003-of-00003.safetensors", |
| 337 | + ] |
| 338 | + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) |
| 339 | + assert all(variant in f if allowed_non_variant not in f else variant not in f for f in model_filenames) |
| 340 | + |
| 341 | + |
213 | 342 | class ProgressBarTests(unittest.TestCase):
|
214 | 343 | def get_dummy_components_image_generation(self):
|
215 | 344 | cross_attention_dim = 8
|
|
0 commit comments