|
10 | 10 | from . import dtype_helpers as dh
|
11 | 11 | from . import pytest_helpers as ph
|
12 | 12 | from . import xps
|
13 |
| -from .typing import Shape, DataType, Array |
| 13 | +from .typing import Shape, DataType, Array, Scalar |
14 | 14 |
|
15 | 15 |
|
16 | 16 | def assert_default_float(func_name: str, dtype: DataType):
|
@@ -43,21 +43,25 @@ def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
|
43 | 43 | assert out_dtype == kw_dtype, msg
|
44 | 44 |
|
45 | 45 |
|
46 |
| -def assert_shape(func_name: str, out_shape: Shape, expected: Union[int, Shape], **kw): |
| 46 | +def assert_shape( |
| 47 | + func_name: str, out_shape: Shape, expected: Union[int, Shape], /, **kw |
| 48 | +): |
47 | 49 | f_kw = ", ".join(f"{k}={v}" for k, v in kw.items())
|
48 | 50 | msg = f"out.shape={out_shape}, but should be {expected} [{func_name}({f_kw})]"
|
49 | 51 | if isinstance(expected, int):
|
50 | 52 | expected = (expected,)
|
51 | 53 | assert out_shape == expected, msg
|
52 | 54 |
|
53 | 55 |
|
54 |
| -def assert_fill(func_name: str, fill: float, dtype: DataType, out: Array, **kw): |
| 56 | +def assert_fill( |
| 57 | + func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw |
| 58 | +): |
55 | 59 | f_kw = ", ".join(f"{k}={v}" for k, v in kw.items())
|
56 |
| - msg = f"out not filled with {fill} [{func_name}({f_kw})]\n" f"{out=}" |
57 |
| - if math.isnan(fill): |
| 60 | + msg = f"out not filled with {fill_value} [{func_name}({f_kw})]\n" f"{out=}" |
| 61 | + if math.isnan(fill_value): |
58 | 62 | assert ah.all(ah.isnan(out)), msg
|
59 | 63 | else:
|
60 |
| - assert ah.all(ah.equal(out, ah.asarray(fill, dtype=dtype))), msg |
| 64 | + assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), msg |
61 | 65 |
|
62 | 66 |
|
63 | 67 | # Testing xp.arange() requires bounding the start/stop/step arguments to only
|
@@ -375,7 +379,7 @@ def test_linspace(num, dtype, endpoint, data):
|
375 | 379 | # TODO: array assertions ala test_arange
|
376 | 380 |
|
377 | 381 |
|
378 |
| -def make_one(dtype: DataType) -> Union[bool, float]: |
| 382 | +def make_one(dtype: DataType) -> Scalar: |
379 | 383 | if dtype is None or dh.is_float_dtype(dtype):
|
380 | 384 | return 1.0
|
381 | 385 | elif dh.is_int_dtype(dtype):
|
@@ -411,7 +415,7 @@ def test_ones_like(x, kw):
|
411 | 415 | assert_fill("ones_like", make_one(dtype), dtype, out)
|
412 | 416 |
|
413 | 417 |
|
414 |
| -def make_zero(dtype: DataType) -> Union[bool, float]: |
| 418 | +def make_zero(dtype: DataType) -> Scalar: |
415 | 419 | if dtype is None or dh.is_float_dtype(dtype):
|
416 | 420 | return 0.0
|
417 | 421 | elif dh.is_int_dtype(dtype):
|
|
0 commit comments