
Commit fbe2fe5

Authored by Yao Matrix
enable consistency test cases on XPU, all passed (#11446)
Signed-off-by: Yao Matrix <matrix.yao@intel.com>
1 parent c865115 commit fbe2fe5

File tree

1 file changed (+15 -5 lines)

tests/pipelines/consistency_models/test_consistency_models.py

Lines changed: 15 additions & 5 deletions
@@ -11,10 +11,12 @@
     UNet2DModel,
 )
 from diffusers.utils.testing_utils import (
+    Expectations,
+    backend_empty_cache,
     enable_full_determinism,
     nightly,
     require_torch_2,
-    require_torch_gpu,
+    require_torch_accelerator,
     torch_device,
 )
 from diffusers.utils.torch_utils import randn_tensor
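
The import block now pulls in `Expectations` and `backend_empty_cache` and swaps the CUDA-only `require_torch_gpu` marker for `require_torch_accelerator`, so the slow tests gate on "any supported accelerator" rather than on CUDA specifically. A rough sketch of that gating idea follows; it is not the diffusers implementation, and `require_accelerator` / `_accelerator_available` are made-up names used only for illustration:

import unittest

import torch


def _accelerator_available() -> bool:
    # CUDA covers NVIDIA (and ROCm builds); XPU covers Intel GPUs.
    if torch.cuda.is_available():
        return True
    return hasattr(torch, "xpu") and torch.xpu.is_available()


def require_accelerator(test_case):
    # Skip the wrapped test when neither CUDA nor XPU is usable on this machine.
    return unittest.skipUnless(
        _accelerator_available(), "test requires a CUDA or XPU accelerator"
    )(test_case)
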
@@ -168,17 +170,17 @@ def test_consistency_model_pipeline_onestep_class_cond(self):
 
 
 @nightly
-@require_torch_gpu
+@require_torch_accelerator
 class ConsistencyModelPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     def get_inputs(self, seed=0, get_fixed_latents=False, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)):
         generator = torch.manual_seed(seed)
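
setUp and tearDown now clear the allocator cache through `backend_empty_cache(torch_device)` instead of calling `torch.cuda.empty_cache()` directly, so the same cleanup works when `torch_device` is "xpu". A minimal sketch of such a dispatch helper, assuming it only needs to branch on the device string (the real helper in diffusers.utils.testing_utils may cover more backends); `empty_cache_for` is a hypothetical name:

import gc

import torch


def empty_cache_for(device: str) -> None:
    # Release cached allocator blocks on whichever backend the tests run on.
    if device.startswith("cuda"):
        torch.cuda.empty_cache()
    elif device.startswith("xpu"):
        # Requires a PyTorch build with Intel XPU support.
        torch.xpu.empty_cache()
    elif device.startswith("mps"):
        torch.mps.empty_cache()
    # CPU has no allocator cache; a gc pass is the closest equivalent.
    gc.collect()
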
@@ -264,11 +266,19 @@ def test_consistency_model_cd_multistep_flash_attn(self):
         # Ensure usage of flash attention in torch 2.0
         with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
             image = pipe(**inputs).images
+
         assert image.shape == (1, 64, 64, 3)
 
         image_slice = image[0, -3:, -3:, -1]
 
-        expected_slice = np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314])
+        expected_slices = Expectations(
+            {
+                ("xpu", 3): np.array([0.0816, 0.0518, 0.0445, 0.0594, 0.0739, 0.0534, 0.0805, 0.0457, 0.0765]),
+                ("cuda", 7): np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314]),
+                ("cuda", 8): np.array([0.0816, 0.0518, 0.0445, 0.0594, 0.0739, 0.0534, 0.0805, 0.0457, 0.0765]),
+            }
+        )
+        expected_slice = expected_slices.get_expectation()
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
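
Instead of one hard-coded reference slice, the test now records a slice per backend in an `Expectations` mapping and resolves the right one at runtime with `get_expectation()`. A sketch of how such a lookup could work, assuming keys are (device type, major version/capability) tuples as in the diff above; `pick_expected` is a hypothetical helper, not the diffusers class:

import numpy as np
import torch


def pick_expected(expectations: dict, device: str) -> np.ndarray:
    # Assumed key layout: (device_type, major_tier), e.g. ("cuda", 8) for Ampere.
    if device == "cuda":
        tier = torch.cuda.get_device_capability()[0]
    else:
        tier = None  # tier detection for other backends is left out of this sketch
    if (device, tier) in expectations:
        return expectations[(device, tier)]
    # Fall back to any entry recorded for the same device type.
    for (dev, _), value in expectations.items():
        if dev == device:
            return value
    raise KeyError(f"no expected slice recorded for device {device!r}")

Keeping all backend-specific reference values in one table leaves the assertion itself identical across devices.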
