This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 52e5d0a

Add utility for filtering out skipped tests in large parametrization groups
ghstack-source-id: d99192c
Pull Request resolved: #303
1 parent 7e7fbec commit 52e5d0a

3 files changed (+56, -28 lines)


float8_experimental/float8_linear.py

Lines changed: 1 addition & 1 deletion
@@ -300,7 +300,7 @@ def cast_x_to_float8(
         if torch.is_autocast_enabled():
             # For now, hardcode to GPU's autocast dtype
             # if we need CPU support in the future, we can add it
-            autocast_dtype = torch.get_autocast_gpu_dtype()
+            autocast_dtype = torch.get_autocast_dtype("cuda")
             x = x.to(autocast_dtype)

         if self.scaling_type_x is TensorScalingType.DELAYED:
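
Note (not part of this commit): torch.get_autocast_gpu_dtype() is the older GPU-specific query, and torch.get_autocast_dtype(device_type) is its device-generic replacement. A minimal sketch of the behavior, assuming a PyTorch version that exposes torch.get_autocast_dtype and an available CUDA device:

import torch

# Inside an autocast region, the device-generic query returns the active
# autocast dtype for that device type (torch.bfloat16 in this sketch).
if torch.cuda.is_available():
    with torch.autocast("cuda", dtype=torch.bfloat16):
        if torch.is_autocast_enabled():
            autocast_dtype = torch.get_autocast_dtype("cuda")
            x = torch.randn(8, 8, device="cuda").to(autocast_dtype)
            assert x.dtype is torch.bfloat16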

float8_experimental/float8_linear_utils.py

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         fp8_layers = get_float8_layers(model)

     if len(fp8_layers) == 0:
-        log.warn(
+        log.warning(
             "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers"
         )
         return
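
Note (not part of this commit): logging.Logger.warn is a deprecated alias for logging.Logger.warning in the Python standard library, which is why the spelling changes here. A standalone sketch of the supported call:

import logging

log = logging.getLogger(__name__)

# warning() is the supported method; warn() is a deprecated alias that can
# emit a DeprecationWarning on newer Python versions.
log.warning(
    "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers"
)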

test/test_base.py

Lines changed: 54 additions & 26 deletions
@@ -10,6 +10,8 @@
 import re
 import unittest
 import warnings
+from itertools import product
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import pytest

@@ -50,6 +52,42 @@
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


+def filtered_parametrize(
+    param_list: List[Tuple[str, List[Any]]],
+    filter_func: Optional[Callable[[Dict[str, Any]], bool]] = None,
+):
+    """
+    A decorator that works like pytest.mark.parametrize but filters out
+    unwanted parameter combinations.
+
+    Args:
+        param_list: A list of tuples, each containing (arg_name, [arg_values])
+        filter_func: A function that takes a dictionary of parameter names and values,
+            and returns True for valid combinations, False otherwise
+
+    """
+
+    def decorator(func):
+        arg_names = [param[0] for param in param_list]
+        arg_values = [param[1] for param in param_list]
+
+        all_combinations = product(*arg_values)
+        if filter_func:
+            valid_combinations = [
+                combo
+                for combo in all_combinations
+                if filter_func(dict(zip(arg_names, combo)))
+            ]
+        else:
+            valid_combinations = list(all_combinations)
+
+        return pytest.mark.parametrize(
+            argnames=arg_names, argvalues=valid_combinations
+        )(func)
+
+    return decorator
+
+
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._data == b._data).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -243,48 +281,38 @@ def test_linear(
         scaling_type_x: TensorScalingType,
         scaling_type_w: TensorScalingType,
         scaling_type_dL_dY: TensorScalingType,
-        linear_dtype: torch.dtype,
-        linear_bias: bool,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-        m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
+        x = torch.randn(*x_shape, device="cuda")
+        m_ref = nn.Linear(16, 32, bias=False, device="cuda")
         self._test_linear_impl(
             x,
             m_ref,
+            linear_type,
             emulate,
             scaling_type_x,
             scaling_type_w,
             scaling_type_dL_dY,
         )
-
-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize(
-        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+
+    @filtered_parametrize(
+        [
+            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+            ("emulate", [True, False] if is_H100 else [True]),
+            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            (
+                "scaling_type_dL_dY",
+                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+            ),
+            ("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]),
+        ],
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_autocast_outputs(
         self,
         emulate: bool,
         linear_dtype: torch.dtype,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()

         m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         kwargs = {