diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 4ba6f7c25eac..7a524e76f16e 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -1186,6 +1186,13 @@ def _is_torch_fp64_available(device):
         "mps": 0,
         "default": 0,
     }
+    BACKEND_SYNCHRONIZE = {
+        "cuda": torch.cuda.synchronize,
+        "xpu": getattr(torch.xpu, "synchronize", None),
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
 
 
 # This dispatches a defined function according to the accelerator from the function definitions.
@@ -1208,6 +1215,10 @@ def backend_manual_seed(device: str, seed: int):
     return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
 
 
+def backend_synchronize(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+
+
 def backend_empty_cache(device: str):
     return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
 
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index f82a2407f333..4faa52cf7e8a 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -59,6 +59,9 @@
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_peak_memory_stats,
+    backend_synchronize,
     floats_tensor,
     get_python_version,
     is_torch_compile,
@@ -68,7 +71,6 @@
     require_torch_2,
     require_torch_accelerator,
     require_torch_accelerator_with_training,
-    require_torch_gpu,
     require_torch_multi_accelerator,
     run_test_in_subprocess,
     slow,
@@ -341,7 +343,7 @@ def test_weight_overwrite(self):
 
         assert model.config.in_channels == 9
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_keep_modules_in_fp32(self):
         r"""
         A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16
@@ -1480,16 +1482,16 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
         test_layerwise_casting(torch.float8_e5m2, torch.float32)
         test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
 
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_layerwise_casting_memory(self):
         MB_TOLERANCE = 0.2
         LEAST_COMPUTE_CAPABILITY = 8.0
 
         def reset_memory_stats():
             gc.collect()
-            torch.cuda.synchronize()
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
+            backend_synchronize(torch_device)
+            backend_empty_cache(torch_device)
+            backend_reset_peak_memory_stats(torch_device)
 
         def get_memory_usage(storage_dtype, compute_dtype):
             torch.manual_seed(0)
@@ -1502,7 +1504,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
             reset_memory_stats()
             model(**inputs_dict)
             model_memory_footprint = model.get_memory_footprint()
-            peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2
+            peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
 
             return model_memory_footprint, peak_inference_memory_allocated_mb
 
@@ -1512,7 +1514,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
             torch.float8_e4m3fn, torch.bfloat16
         )
 
-        compute_capability = get_torch_cuda_device_capability()
+        compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
         self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
         # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
         # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
@@ -1527,7 +1529,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
         )
 
     @parameterized.expand([False, True])
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_group_offloading(self, record_stream):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        torch.manual_seed(0)
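For readers outside the diffusers codebase: the new `backend_synchronize` helper and `BACKEND_SYNCHRONIZE` table plug into the same `_device_agnostic_dispatch` pattern used by the other `BACKEND_*` tables. The sketch below is a hypothetical, self-contained approximation of that pattern (the real implementation lives in `src/diffusers/utils/testing_utils.py` and may differ in details); it only illustrates why `None` entries make synchronization a safe no-op on backends such as CPU and MPS.

```python
from typing import Any, Callable, Dict, Optional

import torch

# Illustrative dispatch table, mirroring the one added in the diff: maps a
# device string to a backend-specific synchronize function, or to None when
# the backend has nothing to synchronize.
BACKEND_SYNCHRONIZE: Dict[str, Optional[Callable[..., Any]]] = {
    "cuda": torch.cuda.synchronize,
    "xpu": getattr(torch.xpu, "synchronize", None) if hasattr(torch, "xpu") else None,
    "cpu": None,
    "mps": None,
    "default": None,
}


def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Optional[Callable[..., Any]]], *args) -> Any:
    # Hypothetical approximation: look up the device, fall back to "default",
    # and treat a None entry as a no-op for that backend.
    fn = dispatch_table.get(device, dispatch_table["default"])
    if fn is None:
        return None
    return fn(*args)


def backend_synchronize(device: str) -> None:
    # Same shape as the helper added in the diff: synchronize only where the
    # backend actually exposes a synchronize() call.
    _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
```

Under that assumption, the rewritten `reset_memory_stats` helper in the test can call `backend_synchronize(torch_device)`, `backend_empty_cache(torch_device)`, and `backend_reset_peak_memory_stats(torch_device)` unconditionally: on backends that lack the corresponding CUDA-style APIs the calls degrade to no-ops instead of raising.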