|
11 | 11 | UNet2DModel,
|
12 | 12 | )
|
13 | 13 | from diffusers.utils.testing_utils import (
|
| 14 | + Expectations, |
| 15 | + backend_empty_cache, |
14 | 16 | enable_full_determinism,
|
15 | 17 | nightly,
|
16 | 18 | require_torch_2,
|
17 |
| - require_torch_gpu, |
| 19 | + require_torch_accelerator, |
18 | 20 | torch_device,
|
19 | 21 | )
|
20 | 22 | from diffusers.utils.torch_utils import randn_tensor
|
@@ -168,17 +170,17 @@ def test_consistency_model_pipeline_onestep_class_cond(self):
|
168 | 170 |
|
169 | 171 |
|
170 | 172 | @nightly
|
171 |
| -@require_torch_gpu |
| 173 | +@require_torch_accelerator |
172 | 174 | class ConsistencyModelPipelineSlowTests(unittest.TestCase):
|
173 | 175 | def setUp(self):
|
174 | 176 | super().setUp()
|
175 | 177 | gc.collect()
|
176 |
| - torch.cuda.empty_cache() |
| 178 | + backend_empty_cache(torch_device) |
177 | 179 |
|
178 | 180 | def tearDown(self):
|
179 | 181 | super().tearDown()
|
180 | 182 | gc.collect()
|
181 |
| - torch.cuda.empty_cache() |
| 183 | + backend_empty_cache(torch_device) |
182 | 184 |
|
183 | 185 | def get_inputs(self, seed=0, get_fixed_latents=False, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)):
|
184 | 186 | generator = torch.manual_seed(seed)
|
@@ -264,11 +266,19 @@ def test_consistency_model_cd_multistep_flash_attn(self):
|
264 | 266 | # Ensure usage of flash attention in torch 2.0
|
265 | 267 | with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
266 | 268 | image = pipe(**inputs).images
|
| 269 | + |
267 | 270 | assert image.shape == (1, 64, 64, 3)
|
268 | 271 |
|
269 | 272 | image_slice = image[0, -3:, -3:, -1]
|
270 | 273 |
|
271 |
| - expected_slice = np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314]) |
| 274 | + expected_slices = Expectations( |
| 275 | + { |
| 276 | + ("xpu", 3): np.array([0.0816, 0.0518, 0.0445, 0.0594, 0.0739, 0.0534, 0.0805, 0.0457, 0.0765]), |
| 277 | + ("cuda", 7): np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314]), |
| 278 | + ("cuda", 8): np.array([0.0816, 0.0518, 0.0445, 0.0594, 0.0739, 0.0534, 0.0805, 0.0457, 0.0765]), |
| 279 | + } |
| 280 | + ) |
| 281 | + expected_slice = expected_slices.get_expectation() |
272 | 282 |
|
273 | 283 | assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
274 | 284 |
|
|
0 commit comments