Skip to content

Commit 0326aa3

Browse files
committed
Remove redundant calls to dh.get_scalar_type()
1 parent 6e4564b commit 0326aa3

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

array_api_tests/test_statistical_functions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def test_max(x, data):
125125
s = scalar_type(x[idx])
126126
elements.append(s)
127127
expected = max(elements)
128-
assert_equals("max", dh.get_scalar_type(out.dtype), out_idx, max_, expected)
128+
assert_equals("max", scalar_type, out_idx, max_, expected)
129129

130130

131131
@given(
@@ -154,7 +154,7 @@ def test_mean(x, data):
154154
s = float(x[idx])
155155
elements.append(s)
156156
expected = sum(elements) / len(elements)
157-
assert_equals("mean", dh.get_scalar_type(out.dtype), out_idx, mean, expected)
157+
assert_equals("mean", float, out_idx, mean, expected)
158158

159159

160160
@given(
@@ -183,7 +183,7 @@ def test_min(x, data):
183183
s = scalar_type(x[idx])
184184
elements.append(s)
185185
expected = min(elements)
186-
assert_equals("min", dh.get_scalar_type(out.dtype), out_idx, min_, expected)
186+
assert_equals("min", scalar_type, out_idx, min_, expected)
187187

188188

189189
@given(
@@ -246,7 +246,7 @@ def test_prod(x, data):
246246
if dh.is_int_dtype(out.dtype):
247247
m, M = dh.dtype_ranges[out.dtype]
248248
assume(m <= expected <= M)
249-
assert_equals("prod", dh.get_scalar_type(out.dtype), out_idx, prod, expected)
249+
assert_equals("prod", scalar_type, out_idx, prod, expected)
250250

251251

252252
@given(
@@ -344,7 +344,7 @@ def test_sum(x, data):
344344
if dh.is_int_dtype(out.dtype):
345345
m, M = dh.dtype_ranges[out.dtype]
346346
assume(m <= expected <= M)
347-
assert_equals("sum", dh.get_scalar_type(out.dtype), out_idx, sum_, expected)
347+
assert_equals("sum", scalar_type, out_idx, sum_, expected)
348348

349349

350350
@given(

0 commit comments

Comments
 (0)