Skip to content

Commit 7fe9c96

Browse files
committed
Use pos-only args for assert helpers with kwargs
1 parent af9c062 commit 7fe9c96

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from . import dtype_helpers as dh
1111
from . import pytest_helpers as ph
1212
from . import xps
13-
from .typing import Shape, DataType, Array
13+
from .typing import Shape, DataType, Array, Scalar
1414

1515

1616
def assert_default_float(func_name: str, dtype: DataType):
@@ -43,21 +43,25 @@ def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
4343
assert out_dtype == kw_dtype, msg
4444

4545

46-
def assert_shape(func_name: str, out_shape: Shape, expected: Union[int, Shape], **kw):
46+
def assert_shape(
47+
func_name: str, out_shape: Shape, expected: Union[int, Shape], /, **kw
48+
):
4749
f_kw = ", ".join(f"{k}={v}" for k, v in kw.items())
4850
msg = f"out.shape={out_shape}, but should be {expected} [{func_name}({f_kw})]"
4951
if isinstance(expected, int):
5052
expected = (expected,)
5153
assert out_shape == expected, msg
5254

5355

54-
def assert_fill(func_name: str, fill: float, dtype: DataType, out: Array, **kw):
56+
def assert_fill(
57+
func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw
58+
):
5559
f_kw = ", ".join(f"{k}={v}" for k, v in kw.items())
56-
msg = f"out not filled with {fill} [{func_name}({f_kw})]\n" f"{out=}"
57-
if math.isnan(fill):
60+
msg = f"out not filled with {fill_value} [{func_name}({f_kw})]\n" f"{out=}"
61+
if math.isnan(fill_value):
5862
assert ah.all(ah.isnan(out)), msg
5963
else:
60-
assert ah.all(ah.equal(out, ah.asarray(fill, dtype=dtype))), msg
64+
assert ah.all(ah.equal(out, ah.asarray(fill_value, dtype=dtype))), msg
6165

6266

6367
# Testing xp.arange() requires bounding the start/stop/step arguments to only
@@ -375,7 +379,7 @@ def test_linspace(num, dtype, endpoint, data):
375379
# TODO: array assertions ala test_arange
376380

377381

378-
def make_one(dtype: DataType) -> Union[bool, float]:
382+
def make_one(dtype: DataType) -> Scalar:
379383
if dtype is None or dh.is_float_dtype(dtype):
380384
return 1.0
381385
elif dh.is_int_dtype(dtype):
@@ -411,7 +415,7 @@ def test_ones_like(x, kw):
411415
assert_fill("ones_like", make_one(dtype), dtype, out)
412416

413417

414-
def make_zero(dtype: DataType) -> Union[bool, float]:
418+
def make_zero(dtype: DataType) -> Scalar:
415419
if dtype is None or dh.is_float_dtype(dtype):
416420
return 0.0
417421
elif dh.is_int_dtype(dtype):

array_api_tests/typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
__all__ = [
44
"DataType",
5+
"Scalar",
56
"ScalarType",
67
"Array",
78
"Shape",
89
"Param",
910
]
1011

1112
DataType = Type[Any]
13+
Scalar = Union[bool, int, float]
1214
ScalarType = Union[Type[bool], Type[int], Type[float]]
1315
Array = Any
1416
Shape = Tuple[int, ...]

0 commit comments

Comments
 (0)