diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py index e255cb510c42..7c7cecdfb014 100644 --- a/tests/pipelines/consistency_models/test_consistency_models.py +++ b/tests/pipelines/consistency_models/test_consistency_models.py @@ -11,10 +11,12 @@ UNet2DModel, ) from diffusers.utils.testing_utils import ( + Expectations, + backend_empty_cache, enable_full_determinism, nightly, require_torch_2, - require_torch_gpu, + require_torch_accelerator, torch_device, ) from diffusers.utils.torch_utils import randn_tensor @@ -168,17 +170,17 @@ def test_consistency_model_pipeline_onestep_class_cond(self): @nightly -@require_torch_gpu +@require_torch_accelerator class ConsistencyModelPipelineSlowTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_inputs(self, seed=0, get_fixed_latents=False, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)): generator = torch.manual_seed(seed) @@ -264,11 +266,19 @@ def test_consistency_model_cd_multistep_flash_attn(self): # Ensure usage of flash attention in torch 2.0 with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): image = pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) image_slice = image[0, -3:, -3:, -1] - expected_slice = np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314]) + expected_slices = Expectations( + { + ("xpu", 3): np.array([0.0816, 0.0518, 0.0445, 0.0594, 0.0739, 0.0534, 0.0805, 0.0457, 0.0765]), + ("cuda", 7): np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314]), + ("cuda", 8): np.array([0.0816, 0.0518, 0.0445, 0.0594, 0.0739, 0.0534, 0.0805, 0.0457, 0.0765]), + } + ) + expected_slice = expected_slices.get_expectation() assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3