Skip to content

Commit 7c68b29

Browse files
Added numpy-like posinf and neginf
1 parent 64e60ed commit 7c68b29

File tree

2 files changed

+44
-12
lines changed

2 files changed

+44
-12
lines changed

pytensor/tensor/math.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,22 @@ def largest(*args):
811811
return max(stack(args), axis=0)
812812

813813

814+
def isposinf(x):
815+
"""
816+
Return if the input variable has positive infinity element
817+
818+
"""
819+
return eq(x, np.inf)
820+
821+
822+
def isneginf(x):
823+
"""
824+
Return if the input variable has negative infinity element
825+
826+
"""
827+
return eq(x, -np.inf)
828+
829+
814830
@scalar_elemwise
815831
def lt(a, b):
816832
"""a < b"""
@@ -3077,11 +3093,8 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
30773093
"""
30783094
# Replace NaN's with nan keyword
30793095
is_nan = isnan(x)
3080-
is_pos_inf = eq(x, np.inf)
3081-
is_neg_inf = eq(x, -np.inf)
3082-
3083-
if not any(is_nan) and not any(is_pos_inf) and not any(is_neg_inf):
3084-
return
3096+
is_pos_inf = isposinf(x)
3097+
is_neg_inf = isneginf(x)
30853098

30863099
x = switch(is_nan, nan, x)
30873100

@@ -3140,6 +3153,8 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
31403153
"not_equal",
31413154
"isnan",
31423155
"isinf",
3156+
"isposinf",
3157+
"isneginf",
31433158
"allclose",
31443159
"isclose",
31453160
"and_",

tests/tensor/test_math.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@
7979
isinf,
8080
isnan,
8181
isnan_,
82+
isneginf,
83+
isposinf,
8284
log,
8385
log1mexp,
8486
log1p,
@@ -3644,6 +3646,26 @@ def test_grad_n_undefined(self):
36443646
grad(polygamma(n, 0.5), wrt=n)
36453647

36463648

3649+
def test_infs():
3650+
x = tensor(shape=(7,))
3651+
3652+
f_pos = function([x], isposinf(x))
3653+
f_neg = function([x], isneginf(x))
3654+
3655+
y = np.array([1, np.inf, 2, np.inf, -np.inf, -np.inf, 4]).astype(x.dtype)
3656+
out_pos = f_pos(y)
3657+
out_neg = f_neg(y)
3658+
3659+
np.testing.assert_allclose(
3660+
out_pos,
3661+
[0, 1, 0, 1, 0, 0, 0],
3662+
)
3663+
np.testing.assert_allclose(
3664+
out_neg,
3665+
[0, 0, 0, 0, 1, 1, 0],
3666+
)
3667+
3668+
36473669
@pytest.mark.parametrize(
36483670
["nan", "posinf", "neginf"],
36493671
[(0, None, None), (0, 0, 0), (0, None, 1000), (3, 1, -1)],
@@ -3653,14 +3675,9 @@ def test_nan_to_num(nan, posinf, neginf):
36533675

36543676
out = nan_to_num(x, nan, posinf, neginf)
36553677

3656-
f = function(
3657-
[x],
3658-
nan_to_num(x, nan, posinf, neginf),
3659-
on_unused_input="warn",
3660-
allow_input_downcast=True,
3661-
)
3678+
f = function([x], out)
36623679

3663-
y = np.array([1, 2, np.nan, np.inf, -np.inf, 3, 4])
3680+
y = np.array([1, 2, np.nan, np.inf, -np.inf, 3, 4]).astype(x.dtype)
36643681
out = f(y)
36653682

36663683
posinf = np.finfo(x.real.dtype).max if posinf is None else posinf

0 commit comments

Comments
 (0)