Skip to content

Commit 0685e3c

Browse files
committed
dh.get_scalar_type() helper
1 parent bd4c155 commit 0685e3c

File tree

2 files changed

+22
-46
lines changed

2 files changed

+22
-46
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
'dtype_to_scalars',
2020
'is_int_dtype',
2121
'is_float_dtype',
22+
'get_scalar_type',
2223
'dtype_ranges',
2324
'default_int',
2425
'default_float',
@@ -75,6 +76,15 @@ def is_float_dtype(dtype):
7576
return dtype in float_dtypes
7677

7778

79+
def get_scalar_type(dtype: DataType) -> ScalarType:
80+
if is_int_dtype(dtype):
81+
return int
82+
elif is_float_dtype(dtype):
83+
return float
84+
else:
85+
return bool
86+
87+
7888
class MinMax(NamedTuple):
7989
min: int
8090
max: int

array_api_tests/test_elementwise_functions.py

Lines changed: 12 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -709,19 +709,13 @@ def test_equal(
709709
_left = ah.asarray(_left, dtype=promoted_dtype)
710710
_right = ah.asarray(_right, dtype=promoted_dtype)
711711

712-
if dh.is_int_dtype(promoted_dtype):
713-
scalar_func = int
714-
elif dh.is_float_dtype(promoted_dtype):
715-
scalar_func = float
716-
else:
717-
scalar_func = bool
718-
712+
scalar_type = dh.get_scalar_type(promoted_dtype)
719713
for idx in ah.ndindex(shape):
720714
x1_idx = _left[idx]
721715
x2_idx = _right[idx]
722716
out_idx = out[idx]
723717
assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check
724-
assert bool(out_idx) == (scalar_func(x1_idx) == scalar_func(x2_idx))
718+
assert bool(out_idx) == (scalar_type(x1_idx) == scalar_type(x2_idx))
725719

726720

727721
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
@@ -840,18 +834,13 @@ def test_greater(
840834
_left = ah.asarray(_left, dtype=promoted_dtype)
841835
_right = ah.asarray(_right, dtype=promoted_dtype)
842836

843-
if dh.is_int_dtype(promoted_dtype):
844-
scalar_func = int
845-
elif dh.is_float_dtype(promoted_dtype):
846-
scalar_func = float
847-
else:
848-
scalar_func = bool
837+
scalar_type = dh.get_scalar_type(promoted_dtype)
849838
for idx in ah.ndindex(shape):
850839
out_idx = out[idx]
851840
x1_idx = _left[idx]
852841
x2_idx = _right[idx]
853842
assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check
854-
assert bool(out_idx) == (scalar_func(x1_idx) > scalar_func(x2_idx))
843+
assert bool(out_idx) == (scalar_type(x1_idx) > scalar_type(x2_idx))
855844

856845

857846
@pytest.mark.parametrize(
@@ -886,18 +875,13 @@ def test_greater_equal(
886875
_left = ah.asarray(_left, dtype=promoted_dtype)
887876
_right = ah.asarray(_right, dtype=promoted_dtype)
888877

889-
if dh.is_int_dtype(promoted_dtype):
890-
scalar_func = int
891-
elif dh.is_float_dtype(promoted_dtype):
892-
scalar_func = float
893-
else:
894-
scalar_func = bool
878+
scalar_type = dh.get_scalar_type(promoted_dtype)
895879
for idx in ah.ndindex(shape):
896880
out_idx = out[idx]
897881
x1_idx = _left[idx]
898882
x2_idx = _right[idx]
899883
assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check
900-
assert bool(out_idx) == (scalar_func(x1_idx) >= scalar_func(x2_idx))
884+
assert bool(out_idx) == (scalar_type(x1_idx) >= scalar_type(x2_idx))
901885

902886

903887
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
@@ -983,19 +967,13 @@ def test_less(
983967
_left = ah.asarray(_left, dtype=promoted_dtype)
984968
_right = ah.asarray(_right, dtype=promoted_dtype)
985969

986-
if dh.is_int_dtype(promoted_dtype):
987-
scalar_func = int
988-
elif dh.is_float_dtype(promoted_dtype):
989-
scalar_func = float
990-
else:
991-
scalar_func = bool
992-
970+
scalar_type = dh.get_scalar_type(promoted_dtype)
993971
for idx in ah.ndindex(shape):
994972
x1_idx = _left[idx]
995973
x2_idx = _right[idx]
996974
out_idx = out[idx]
997975
assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check
998-
assert bool(out_idx) == (scalar_func(x1_idx) < scalar_func(x2_idx))
976+
assert bool(out_idx) == (scalar_type(x1_idx) < scalar_type(x2_idx))
999977

1000978

1001979
@pytest.mark.parametrize(
@@ -1030,19 +1008,13 @@ def test_less_equal(
10301008
_left = ah.asarray(_left, dtype=promoted_dtype)
10311009
_right = ah.asarray(_right, dtype=promoted_dtype)
10321010

1033-
if dh.is_int_dtype(promoted_dtype):
1034-
scalar_func = int
1035-
elif dh.is_float_dtype(promoted_dtype):
1036-
scalar_func = float
1037-
else:
1038-
scalar_func = bool
1039-
1011+
scalar_type = dh.get_scalar_type(promoted_dtype)
10401012
for idx in ah.ndindex(shape):
10411013
x1_idx = _left[idx]
10421014
x2_idx = _right[idx]
10431015
out_idx = out[idx]
10441016
assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check
1045-
assert bool(out_idx) == (scalar_func(x1_idx) <= scalar_func(x2_idx))
1017+
assert bool(out_idx) == (scalar_type(x1_idx) <= scalar_type(x2_idx))
10461018

10471019

10481020
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
@@ -1241,19 +1213,13 @@ def test_not_equal(
12411213
_left = ah.asarray(_left, dtype=promoted_dtype)
12421214
_right = ah.asarray(_right, dtype=promoted_dtype)
12431215

1244-
if dh.is_int_dtype(promoted_dtype):
1245-
scalar_func = int
1246-
elif dh.is_float_dtype(promoted_dtype):
1247-
scalar_func = float
1248-
else:
1249-
scalar_func = bool
1250-
1216+
scalar_type = dh.get_scalar_type(promoted_dtype)
12511217
for idx in ah.ndindex(shape):
12521218
out_idx = out[idx]
12531219
x1_idx = _left[idx]
12541220
x2_idx = _right[idx]
12551221
assert out_idx.shape == x1_idx.shape == x2_idx.shape # sanity check
1256-
assert bool(out_idx) == (scalar_func(x1_idx) != scalar_func(x2_idx))
1222+
assert bool(out_idx) == (scalar_type(x1_idx) != scalar_type(x2_idx))
12571223

12581224

12591225
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)