10 | 10 | import re
11 | 11 | import unittest
12 | 12 | import warnings
   | 13 | +from itertools import product
   | 14 | +from typing import Any, Callable, Dict, List, Optional, Tuple
13 | 15 |
14 | 16 | import pytest
15 | 17 |

50 | 52 | is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
51 | 53 |
52 | 54 |
   | 55 | +def filtered_parametrize(
   | 56 | +    param_list: List[Tuple[str, List[Any]]],
   | 57 | +    filter_func: Optional[Callable[[Dict[str, Any]], bool]] = None,
   | 58 | +):
   | 59 | +    """
   | 60 | +    A decorator that works like pytest.mark.parametrize but filters out
   | 61 | +    unwanted parameter combinations.
   | 62 | +
   | 63 | +    Args:
   | 64 | +        param_list: A list of tuples, each containing (arg_name, [arg_values])
   | 65 | +        filter_func: A function that takes a dictionary of parameter names and values,
   | 66 | +            and returns True for valid combinations, False otherwise
   | 67 | +
   | 68 | +    """
   | 69 | +
   | 70 | +    def decorator(func):
   | 71 | +        arg_names = [param[0] for param in param_list]
   | 72 | +        arg_values = [param[1] for param in param_list]
   | 73 | +
   | 74 | +        all_combinations = product(*arg_values)
   | 75 | +        if filter_func:
   | 76 | +            valid_combinations = [
   | 77 | +                combo
   | 78 | +                for combo in all_combinations
   | 79 | +                if filter_func(dict(zip(arg_names, combo)))
   | 80 | +            ]
   | 81 | +        else:
   | 82 | +            valid_combinations = list(all_combinations)
   | 83 | +
   | 84 | +        return pytest.mark.parametrize(
   | 85 | +            argnames=arg_names, argvalues=valid_combinations
   | 86 | +        )(func)
   | 87 | +
   | 88 | +    return decorator
   | 89 | +
   | 90 | +
53 | 91 | def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
54 | 92 |     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
55 | 93 |     assert torch.all(a._data == b._data).item(), "data is not identical"
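For reference, a minimal usage sketch of the new `filtered_parametrize` helper (illustrative only, not part of this diff; the argument names and the filtering condition are made up, and it relies on the module's existing `torch` and `pytest` imports):

```python
# Hypothetical example: parametrize over two arguments, then drop one
# combination via filter_func. The predicate here is purely illustrative.
@filtered_parametrize(
    [
        ("emulate", [True, False]),
        ("linear_dtype", [torch.bfloat16, torch.float32]),
    ],
    # Keep a combination only if the predicate returns True; this drops
    # non-emulated bfloat16 runs, just to show the mechanism.
    filter_func=lambda p: p["emulate"] or p["linear_dtype"] is torch.float32,
)
def test_example(emulate, linear_dtype):
    assert emulate or linear_dtype is torch.float32
```

Because filtering happens at collection time, rejected combinations never show up as test cases at all, unlike a `pytest.skip()` inside the test body.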
@@ -243,48 +281,38 @@ def test_linear(
243 | 281 |         scaling_type_x: TensorScalingType,
244 | 282 |         scaling_type_w: TensorScalingType,
245 | 283 |         scaling_type_dL_dY: TensorScalingType,
246 |     | -        linear_dtype: torch.dtype,
247 |     | -        linear_bias: bool,
248 | 284 |     ):
249 |     | -        if not emulate:
250 |     | -            if not torch.cuda.is_available():
251 |     | -                warnings.warn("CUDA not available")
252 |     | -                pytest.skip()
253 |     | -            elif torch.cuda.get_device_capability() < (9, 0):
254 |     | -                warnings.warn(
255 |     | -                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
256 |     | -                )
257 |     | -                pytest.skip()
258 |     | -        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
259 |     | -        m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
    | 285 | +        x = torch.randn(*x_shape, device="cuda")
    | 286 | +        m_ref = nn.Linear(16, 32, bias=False, device="cuda")
260 | 287 |         self._test_linear_impl(
261 | 288 |             x,
262 | 289 |             m_ref,
    | 290 | +            linear_type,
263 | 291 |             emulate,
264 | 292 |             scaling_type_x,
265 | 293 |             scaling_type_w,
266 | 294 |             scaling_type_dL_dY,
267 | 295 |         )
268 |     | -
269 |     | -    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
270 |     | -    @pytest.mark.parametrize(
271 |     | -        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
    | 296 | +
    | 297 | +    @filtered_parametrize(
    | 298 | +        [
    | 299 | +            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
    | 300 | +            ("emulate", [True, False] if is_H100 else [True]),
    | 301 | +            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
    | 302 | +            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
    | 303 | +            (
    | 304 | +                "scaling_type_dL_dY",
    | 305 | +                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
    | 306 | +            ),
    | 307 | +            ("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]),
    | 308 | +        ],
272 | 309 |     )
273 | 310 |     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
274 | 311 |     def test_autocast_outputs(
275 | 312 |         self,
276 | 313 |         emulate: bool,
277 | 314 |         linear_dtype: torch.dtype,
278 | 315 |     ):
279 |     | -        if not emulate:
280 |     | -            if not torch.cuda.is_available():
281 |     | -                warnings.warn("CUDA not available")
282 |     | -                pytest.skip()
283 |     | -            elif torch.cuda.get_device_capability() < (9, 0):
284 |     | -                warnings.warn(
285 |     | -                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
286 |     | -                )
287 |     | -                pytest.skip()
288 | 316 |
289 | 317 |         m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
290 | 318 |         kwargs = {
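As an aside (not part of the PR): the CUDA-capability skips deleted in the hunks above could also be expressed through the new decorator's `filter_func` hook instead of inside the test body, reusing the module-level `is_H100` flag. The diff itself takes a different route, narrowing the `emulate` values when `is_H100` is false; this sketch only illustrates the hook:

```python
# Sketch under the assumption that non-emulated float8 runs require an
# H100-class GPU (compute capability >= 9.0), as the removed skips checked.
def _hw_supported(params):
    # Emulated runs work anywhere; real runs only when is_H100 is True.
    return params["emulate"] or is_H100

# e.g. @filtered_parametrize([...], filter_func=_hw_supported)
```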