Skip to content

Commit 1a6fa69

Browse files
PipelineTesterMixin parameter configuration refactor (#2502)
* attend and excite batch test causing timeouts * PipelineTesterMixin argument configuration refactor * error message text re: @yiyixuxu * remove eta re: @patrickvonplaten
1 parent 664b4de commit 1a6fa69

27 files changed

+374
-54
lines changed

tests/pipeline_params.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# These are canonical sets of parameters for different types of pipelines.
2+
# They are set on subclasses of `PipelineTesterMixin` as `params` and
3+
# `batch_params`.
4+
#
5+
# If a pipeline's set of arguments has minor changes from one of the common sets
6+
# of arguments, do not make modifications to the existing common sets of arguments.
7+
# I.e. a text to image pipeline with non-configurable height and width arguments
8+
# should set its attribute as `params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`.
9+
10+
TEXT_TO_IMAGE_PARAMS = frozenset(
11+
[
12+
"prompt",
13+
"height",
14+
"width",
15+
"guidance_scale",
16+
"negative_prompt",
17+
"prompt_embeds",
18+
"negative_prompt_embeds",
19+
"cross_attention_kwargs",
20+
]
21+
)
22+
23+
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
24+
25+
IMAGE_VARIATION_PARAMS = frozenset(
26+
[
27+
"image",
28+
"height",
29+
"width",
30+
"guidance_scale",
31+
]
32+
)
33+
34+
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
35+
36+
TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
37+
[
38+
"prompt",
39+
"image",
40+
"height",
41+
"width",
42+
"guidance_scale",
43+
"negative_prompt",
44+
"prompt_embeds",
45+
"negative_prompt_embeds",
46+
]
47+
)
48+
49+
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
50+
51+
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
52+
[
53+
# Text guided image variation with an image mask
54+
"prompt",
55+
"image",
56+
"mask_image",
57+
"height",
58+
"width",
59+
"guidance_scale",
60+
"negative_prompt",
61+
"prompt_embeds",
62+
"negative_prompt_embeds",
63+
]
64+
)
65+
66+
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
67+
68+
IMAGE_INPAINTING_PARAMS = frozenset(
69+
[
70+
# image variation with an image mask
71+
"image",
72+
"mask_image",
73+
"height",
74+
"width",
75+
"guidance_scale",
76+
]
77+
)
78+
79+
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
80+
81+
IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
82+
[
83+
"example_image",
84+
"image",
85+
"mask_image",
86+
"height",
87+
"width",
88+
"guidance_scale",
89+
]
90+
)
91+
92+
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
93+
94+
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
95+
96+
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
97+
98+
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
99+
100+
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
101+
102+
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
103+
104+
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])

tests/pipelines/altdiffusion/test_alt_diffusion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from diffusers.utils import slow, torch_device
2929
from diffusers.utils.testing_utils import require_torch_gpu
3030

31+
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
3132
from ...test_pipelines_common import PipelineTesterMixin
3233

3334

@@ -36,6 +37,8 @@
3637

3738
class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
3839
pipeline_class = AltDiffusionPipeline
40+
params = TEXT_TO_IMAGE_PARAMS
41+
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
3942

4043
def get_dummy_components(self):
4144
torch.manual_seed(0)

tests/pipelines/dance_diffusion/test_dance_diffusion.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from diffusers.utils import slow, torch_device
2424
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
2525

26+
from ...pipeline_params import UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS, UNCONDITIONAL_AUDIO_GENERATION_PARAMS
2627
from ...test_pipelines_common import PipelineTesterMixin
2728

2829

@@ -31,6 +32,15 @@
3132

3233
class DanceDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
3334
pipeline_class = DanceDiffusionPipeline
35+
params = UNCONDITIONAL_AUDIO_GENERATION_PARAMS
36+
required_optional_params = PipelineTesterMixin.required_optional_params - {
37+
"callback",
38+
"latents",
39+
"callback_steps",
40+
"output_type",
41+
"num_images_per_prompt",
42+
}
43+
batch_params = UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS
3444
test_attention_slicing = False
3545
test_cpu_offload = False
3646

tests/pipelines/ddim/test_ddim.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel
2222
from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
2323

24+
from ...pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS
2425
from ...test_pipelines_common import PipelineTesterMixin
2526

2627

@@ -29,6 +30,14 @@
2930

3031
class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
3132
pipeline_class = DDIMPipeline
33+
params = UNCONDITIONAL_IMAGE_GENERATION_PARAMS
34+
required_optional_params = PipelineTesterMixin.required_optional_params - {
35+
"num_images_per_prompt",
36+
"latents",
37+
"callback",
38+
"callback_steps",
39+
}
40+
batch_params = UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS
3241
test_cpu_offload = False
3342

3443
def get_dummy_components(self):

tests/pipelines/dit/test_dit.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
from diffusers.utils import load_numpy, slow
2424
from diffusers.utils.testing_utils import require_torch_gpu
2525

26+
from ...pipeline_params import (
27+
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS,
28+
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS,
29+
)
2630
from ...test_pipelines_common import PipelineTesterMixin
2731

2832

@@ -31,6 +35,14 @@
3135

3236
class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
3337
pipeline_class = DiTPipeline
38+
params = CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS
39+
required_optional_params = PipelineTesterMixin.required_optional_params - {
40+
"latents",
41+
"num_images_per_prompt",
42+
"callback",
43+
"callback_steps",
44+
}
45+
batch_params = CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS
3446
test_cpu_offload = False
3547

3648
def get_dummy_components(self):

tests/pipelines/latent_diffusion/test_latent_diffusion.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel
2424
from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, slow, torch_device
2525

26+
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
2627
from ...test_pipelines_common import PipelineTesterMixin
2728

2829

@@ -31,6 +32,18 @@
3132

3233
class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
3334
pipeline_class = LDMTextToImagePipeline
35+
params = TEXT_TO_IMAGE_PARAMS - {
36+
"negative_prompt",
37+
"negative_prompt_embeds",
38+
"cross_attention_kwargs",
39+
"prompt_embeds",
40+
}
41+
required_optional_params = PipelineTesterMixin.required_optional_params - {
42+
"num_images_per_prompt",
43+
"callback",
44+
"callback_steps",
45+
}
46+
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
3447
test_cpu_offload = False
3548

3649
def get_dummy_components(self):

tests/pipelines/paint_by_example/test_paint_by_example.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from diffusers.utils import floats_tensor, load_image, slow, torch_device
2828
from diffusers.utils.testing_utils import require_torch_gpu
2929

30+
from ...pipeline_params import IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
3031
from ...test_pipelines_common import PipelineTesterMixin
3132

3233

@@ -35,12 +36,8 @@
3536

3637
class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
3738
pipeline_class = PaintByExamplePipeline
38-
39-
def tearDown(self):
40-
# clean up the VRAM after each test
41-
super().tearDown()
42-
gc.collect()
43-
torch.cuda.empty_cache()
39+
params = IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
40+
batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
4441

4542
def get_dummy_components(self):
4643
torch.manual_seed(0)

tests/pipelines/repaint/test_repaint.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel
2323
from diffusers.utils.testing_utils import load_image, load_numpy, nightly, require_torch_gpu, skip_mps, torch_device
2424

25+
from ...pipeline_params import IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_INPAINTING_PARAMS
2526
from ...test_pipelines_common import PipelineTesterMixin
2627

2728

@@ -30,6 +31,14 @@
3031

3132
class RepaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
3233
pipeline_class = RePaintPipeline
34+
params = IMAGE_INPAINTING_PARAMS - {"width", "height", "guidance_scale"}
35+
required_optional_params = PipelineTesterMixin.required_optional_params - {
36+
"latents",
37+
"num_images_per_prompt",
38+
"callback",
39+
"callback_steps",
40+
}
41+
batch_params = IMAGE_INPAINTING_BATCH_PARAMS
3342
test_cpu_offload = False
3443

3544
def get_dummy_components(self):

tests/pipelines/stable_diffusion/test_cycle_diffusion.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
2626
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
2727

28+
from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
2829
from ...test_pipelines_common import PipelineTesterMixin
2930

3031

@@ -33,6 +34,14 @@
3334

3435
class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
3536
pipeline_class = CycleDiffusionPipeline
37+
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {
38+
"negative_prompt",
39+
"height",
40+
"width",
41+
"negative_prompt_embeds",
42+
}
43+
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
44+
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
3645

3746
def get_dummy_components(self):
3847
torch.manual_seed(0)

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
4040

4141
from ...models.test_models_unet_2d_condition import create_lora_layers
42+
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
4243
from ...test_pipelines_common import PipelineTesterMixin
4344

4445

@@ -47,6 +48,8 @@
4748

4849
class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
4950
pipeline_class = StableDiffusionPipeline
51+
params = TEXT_TO_IMAGE_PARAMS
52+
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
5053

5154
def get_dummy_components(self):
5255
torch.manual_seed(0)

tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
3333
from diffusers.utils.testing_utils import require_torch_gpu
3434

35+
from ...pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
3536
from ...test_pipelines_common import PipelineTesterMixin
3637

3738

@@ -40,6 +41,8 @@
4041

4142
class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
4243
pipeline_class = StableDiffusionImageVariationPipeline
44+
params = IMAGE_VARIATION_PARAMS
45+
batch_params = IMAGE_VARIATION_BATCH_PARAMS
4346

4447
def get_dummy_components(self):
4548
torch.manual_seed(0)

tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
3434
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
3535

36+
from ...pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
3637
from ...test_pipelines_common import PipelineTesterMixin
3738

3839

@@ -41,6 +42,9 @@
4142

4243
class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
4344
pipeline_class = StableDiffusionImg2ImgPipeline
45+
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
46+
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
47+
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
4448

4549
def get_dummy_components(self):
4650
torch.manual_seed(0)

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
3535
from diffusers.utils.testing_utils import require_torch_gpu
3636

37+
from ...pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
3738
from ...test_pipelines_common import PipelineTesterMixin
3839

3940

@@ -42,6 +43,8 @@
4243

4344
class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
4445
pipeline_class = StableDiffusionInpaintPipeline
46+
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
47+
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
4548

4649
def get_dummy_components(self):
4750
torch.manual_seed(0)

tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from diffusers.utils import floats_tensor, load_image, slow, torch_device
3535
from diffusers.utils.testing_utils import require_torch_gpu
3636

37+
from ...pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
3738
from ...test_pipelines_common import PipelineTesterMixin
3839

3940

@@ -42,6 +43,8 @@
4243

4344
class StableDiffusionInstructPix2PixPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
4445
pipeline_class = StableDiffusionInstructPix2PixPipeline
46+
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"}
47+
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
4548

4649
def get_dummy_components(self):
4750
torch.manual_seed(0)

tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from diffusers.utils import slow, torch_device
3333
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
3434

35+
from ...pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
3536
from ...test_pipelines_common import PipelineTesterMixin
3637

3738

@@ -41,6 +42,8 @@
4142
@skip_mps
4243
class StableDiffusionPanoramaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
4344
pipeline_class = StableDiffusionPanoramaPipeline
45+
params = TEXT_TO_IMAGE_PARAMS
46+
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
4447

4548
def get_dummy_components(self):
4649
torch.manual_seed(0)

0 commit comments

Comments
 (0)