Skip to content

Commit 7618a6d

Browse files
Added numpy-like posinf and neginf
1 parent 64e60ed commit 7618a6d

File tree

2 files changed

+55
-5
lines changed

2 files changed

+55
-5
lines changed

pytensor/tensor/math.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,22 @@ def largest(*args):
811811
return max(stack(args), axis=0)
812812

813813

814+
def isposinf(x):
815+
"""
816+
Return if the input variable has positive infinity element
817+
818+
"""
819+
return eq(x, np.inf)
820+
821+
822+
def isneginf(x):
823+
"""
824+
Return if the input variable has negative infinity element
825+
826+
"""
827+
return eq(x, -np.inf)
828+
829+
814830
@scalar_elemwise
815831
def lt(a, b):
816832
"""a < b"""
@@ -3077,11 +3093,11 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
30773093
"""
30783094
# Replace NaN's with nan keyword
30793095
is_nan = isnan(x)
3080-
is_pos_inf = eq(x, np.inf)
3081-
is_neg_inf = eq(x, -np.inf)
3096+
is_pos_inf = isposinf(x)
3097+
is_neg_inf = isneginf(x)
30823098

3083-
if not any(is_nan) and not any(is_pos_inf) and not any(is_neg_inf):
3084-
return
3099+
# if not any(is_nan) and not any(is_pos_inf) and not any(is_neg_inf):
3100+
# return
30853101

30863102
x = switch(is_nan, nan, x)
30873103

@@ -3140,6 +3156,8 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
31403156
"not_equal",
31413157
"isnan",
31423158
"isinf",
3159+
"isposinf",
3160+
"isneginf",
31433161
"allclose",
31443162
"isclose",
31453163
"and_",

tests/tensor/test_math.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@
7979
isinf,
8080
isnan,
8181
isnan_,
82+
isneginf,
83+
isposinf,
8284
log,
8385
log1mexp,
8486
log1p,
@@ -3644,6 +3646,36 @@ def test_grad_n_undefined(self):
36443646
grad(polygamma(n, 0.5), wrt=n)
36453647

36463648

3649+
def test_infs():
3650+
x = tensor(shape=(7,))
3651+
3652+
f_pos = function(
3653+
[x],
3654+
isposinf(x),
3655+
on_unused_input="warn",
3656+
allow_input_downcast=True,
3657+
)
3658+
f_neg = function(
3659+
[x],
3660+
isneginf(x),
3661+
on_unused_input="warn",
3662+
allow_input_downcast=True,
3663+
)
3664+
3665+
y = np.array([1, np.inf, 2, np.inf, -np.inf, -np.inf, 4])
3666+
out_pos = f_pos(y)
3667+
out_neg = f_neg(y)
3668+
3669+
np.testing.assert_allclose(
3670+
out_pos,
3671+
[0, 1, 0, 1, 0, 0, 0],
3672+
)
3673+
np.testing.assert_allclose(
3674+
out_neg,
3675+
[0, 0, 0, 0, 1, 1, 0],
3676+
)
3677+
3678+
36473679
@pytest.mark.parametrize(
36483680
["nan", "posinf", "neginf"],
36493681
[(0, None, None), (0, 0, 0), (0, None, 1000), (3, 1, -1)],
@@ -3655,7 +3687,7 @@ def test_nan_to_num(nan, posinf, neginf):
36553687

36563688
f = function(
36573689
[x],
3658-
nan_to_num(x, nan, posinf, neginf),
3690+
out,
36593691
on_unused_input="warn",
36603692
allow_input_downcast=True,
36613693
)

0 commit comments

Comments
 (0)