@@ -68,6 +68,7 @@
     torch_all_close,
     torch_device,
 )
+from diffusers.utils.torch_utils import get_torch_cuda_device_capability

 from ..others.test_utils import TOKEN, USER, is_staging_test

@@ -1384,6 +1385,7 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
     @require_torch_gpu
     def test_layerwise_casting_memory(self):
         MB_TOLERANCE = 0.2
+        LEAST_COMPUTE_CAPABILITY = 8.0

         def reset_memory_stats():
             gc.collect()
@@ -1412,10 +1414,12 @@ def get_memory_usage(storage_dtype, compute_dtype):
             torch.float8_e4m3fn, torch.bfloat16
         )

+        compute_capability = get_torch_cuda_device_capability()
         self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
-        # NOTE: the following assertion will fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
-        # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes.
-        self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
+        # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
+        # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
+        if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
+            self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
         # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
         # bytes. This only happens for some models, so we allow a small tolerance.
         # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
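For context, here is a minimal sketch of how the guarded assertion can work, assuming the new `get_torch_cuda_device_capability` helper reports the active CUDA device's compute capability as a float (e.g. 7.5 for a Tesla T4, 8.0 for Ampere) and `None` when CUDA is unavailable. The sketch builds on `torch.cuda.get_device_capability()` and is not necessarily the exact implementation shipped in `diffusers.utils.torch_utils`:

```python
import torch


def get_torch_cuda_device_capability():
    # Sketch only: the real helper lives in diffusers.utils.torch_utils and
    # may be implemented differently.
    if not torch.cuda.is_available():
        return None
    # torch.cuda.get_device_capability() returns a (major, minor) tuple,
    # e.g. (7, 5) on a Tesla T4 or (8, 0) on an Ampere GPU such as the A100.
    major, minor = torch.cuda.get_device_capability()
    return float(f"{major}.{minor}")


# Mirrors the guard added to the test: the bf16-vs-fp32 peak-memory assertion
# only runs on GPUs with compute capability >= 8.0 (Ampere and newer), which
# skips it on the Tesla T4 (compute capability 7.5) used in CI.
LEAST_COMPUTE_CAPABILITY = 8.0
compute_capability = get_torch_cuda_device_capability()
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
    print("bf16 peak-memory assertion would run on this device")
```

Returning `None` rather than raising when CUDA is absent keeps the guard a single truthiness check, so the assertion is skipped both on CPU-only runners and on pre-Ampere GPUs.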