Skip to content

Commit a9506a8

Browse files
committed
Use cmath where obvious
1 parent 84bd3ef commit a9506a8

File tree

6 files changed

+19
-14
lines changed

6 files changed

+19
-14
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import cmath
12
import math
23
from inspect import getfullargspec
34
from typing import Any, Dict, Optional, Sequence, Tuple, Union
@@ -345,12 +346,12 @@ def assert_scalar_equals(
345346
if type_ in [bool, int]:
346347
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
347348
assert out == expected, msg
348-
elif math.isnan(expected):
349+
elif cmath.isnan(expected):
349350
msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
350-
assert math.isnan(out), msg
351+
assert cmath.isnan(out), msg
351352
else:
352353
msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]"
353-
assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg
354+
assert cmath.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg
354355

355356

356357
def assert_fill(
@@ -368,7 +369,7 @@ def assert_fill(
368369
369370
"""
370371
msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}"
371-
if math.isnan(fill_value):
372+
if cmath.isnan(fill_value):
372373
assert xp.all(xp.isnan(out)), msg
373374
else:
374375
assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg

array_api_tests/test_array_object.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import cmath
12
import math
23
from itertools import product
34
from typing import List, Sequence, Tuple, Union, get_args
@@ -135,7 +136,7 @@ def test_setitem(shape, dtypes, data):
135136
f_res = sh.fmt_idx("x", key)
136137
if isinstance(value, get_args(Scalar)):
137138
msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
138-
if math.isnan(value):
139+
if cmath.isnan(value):
139140
assert xp.isnan(res[key]), msg
140141
else:
141142
assert res[key] == value, msg

array_api_tests/test_creation_functions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import cmath
12
import math
23
from itertools import count
34
from typing import Iterator, NamedTuple, Union
@@ -247,8 +248,8 @@ def test_asarray_scalars(shape, data):
247248

248249

249250
def scalar_eq(s1: Scalar, s2: Scalar) -> bool:
250-
if math.isnan(s1):
251-
return math.isnan(s2)
251+
if cmath.isnan(s1):
252+
return cmath.isnan(s2)
252253
else:
253254
return s1 == s2
254255

array_api_tests/test_set_functions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# TODO: disable if opted out, refactor things
2+
import cmath
23
import math
34
from collections import Counter, defaultdict
45

@@ -61,7 +62,7 @@ def test_unique_all(x):
6162

6263
for idx in sh.ndindex(out.indices.shape):
6364
val = scalar_type(out.values[idx])
64-
if math.isnan(val):
65+
if cmath.isnan(val):
6566
break
6667
i = int(out.indices[idx])
6768
expected = firsts[val]
@@ -88,7 +89,7 @@ def test_unique_all(x):
8889
for idx in sh.ndindex(out.values.shape):
8990
val = scalar_type(out.values[idx])
9091
count = int(out.counts[idx])
91-
if math.isnan(val):
92+
if cmath.isnan(val):
9293
nans += 1
9394
assert count == 1, (
9495
f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
@@ -225,7 +226,7 @@ def test_unique_values(x):
225226
nans = 0
226227
for idx in sh.ndindex(out.shape):
227228
val = scalar_type(out[idx])
228-
if math.isnan(val):
229+
if cmath.isnan(val):
229230
nans += 1
230231
else:
231232
assert val in distinct, f"out[{idx}]={val}, but {val} not in input array"

array_api_tests/test_sorting_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import math
1+
import cmath
22
from typing import Set
33

44
import pytest
@@ -26,7 +26,7 @@ def assert_scalar_in_set(
2626
**kw,
2727
):
2828
out_repr = "out" if idx == () else f"out[{idx}]"
29-
if math.isnan(out):
29+
if cmath.isnan(out):
3030
raise NotImplementedError()
3131
msg = f"{out_repr}={out}, but should be in {set_} [{func_name}({ph.fmt_kw(kw)})]"
3232
assert out in set_, msg

array_api_tests/test_statistical_functions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import cmath
12
import math
23
from typing import Optional
34

@@ -162,7 +163,7 @@ def test_prod(x, data):
162163
scalar_type = dh.get_scalar_type(out.dtype)
163164
for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
164165
prod = scalar_type(out[out_idx])
165-
assume(math.isfinite(prod))
166+
assume(cmath.isfinite(prod))
166167
elements = []
167168
for idx in indices:
168169
s = scalar_type(x[idx])
@@ -267,7 +268,7 @@ def test_sum(x, data):
267268
scalar_type = dh.get_scalar_type(out.dtype)
268269
for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
269270
sum_ = scalar_type(out[out_idx])
270-
assume(math.isfinite(sum_))
271+
assume(cmath.isfinite(sum_))
271272
elements = []
272273
for idx in indices:
273274
s = scalar_type(x[idx])

0 commit comments

Comments
 (0)