Skip to content

Commit 7b100ce

Browse files
authored
[Tests] conditionally check fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory (#10669)
* conditionally check if compute capability is met. * log info. * fix condition. * updates * updates * updates * updates
1 parent c4d4ac2 commit 7b100ce

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

src/diffusers/utils/torch_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,13 @@ def apply_freeu(
149149
res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])
150150

151151
return hidden_states, res_hidden_states
152+
153+
154+
def get_torch_cuda_device_capability():
155+
if torch.cuda.is_available():
156+
device = torch.device("cuda")
157+
compute_capability = torch.cuda.get_device_capability(device)
158+
compute_capability = f"{compute_capability[0]}.{compute_capability[1]}"
159+
return float(compute_capability)
160+
else:
161+
return None

tests/models/test_modeling_common.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
torch_all_close,
6969
torch_device,
7070
)
71+
from diffusers.utils.torch_utils import get_torch_cuda_device_capability
7172

7273
from ..others.test_utils import TOKEN, USER, is_staging_test
7374

@@ -1384,6 +1385,7 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
13841385
@require_torch_gpu
13851386
def test_layerwise_casting_memory(self):
13861387
MB_TOLERANCE = 0.2
1388+
LEAST_COMPUTE_CAPABILITY = 8.0
13871389

13881390
def reset_memory_stats():
13891391
gc.collect()
@@ -1412,10 +1414,12 @@ def get_memory_usage(storage_dtype, compute_dtype):
14121414
torch.float8_e4m3fn, torch.bfloat16
14131415
)
14141416

1417+
compute_capability = get_torch_cuda_device_capability()
14151418
self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
1416-
# NOTE: the following assertion will fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
1417-
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes.
1418-
self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
1419+
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
1420+
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
1421+
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
1422+
self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
14191423
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
14201424
# bytes. This only happens for some models, so we allow a small tolerance.
14211425
# For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.

0 commit comments

Comments
 (0)