Skip to content

Commit ca10298

Browse files
Add nan_to_num helper (#796)
As well as numpy-like posinf and neginf
1 parent 05d376f commit ca10298

File tree

2 files changed

+121
-0
lines changed

2 files changed

+121
-0
lines changed

pytensor/tensor/math.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,22 @@ def largest(*args):
681681
return max(stack(args), axis=0)
682682

683683

684+
def isposinf(x):
685+
"""
686+
Return if the input variable has positive infinity element
687+
688+
"""
689+
return eq(x, np.inf)
690+
691+
692+
def isneginf(x):
693+
"""
694+
Return if the input variable has negative infinity element
695+
696+
"""
697+
return eq(x, -np.inf)
698+
699+
684700
@scalar_elemwise
685701
def lt(a, b):
686702
"""a < b"""
@@ -2913,6 +2929,62 @@ def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
29132929
return vectorize_node_fallback(op, node, batched_x, batched_y)
29142930

29152931

2932+
def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
2933+
"""
2934+
Replace NaN with zero and infinity with large finite numbers (default
2935+
behaviour) or with the numbers defined by the user using the `nan`,
2936+
`posinf` and/or `neginf` keywords.
2937+
2938+
NaN is replaced by zero or by the user defined value in
2939+
`nan` keyword, infinity is replaced by the largest finite floating point
2940+
values representable by ``x.dtype`` or by the user defined value in
2941+
`posinf` keyword and -infinity is replaced by the most negative finite
2942+
floating point values representable by ``x.dtype`` or by the user defined
2943+
value in `neginf` keyword.
2944+
2945+
Parameters
2946+
----------
2947+
x : symbolic tensor
2948+
Input array.
2949+
nan
2950+
The value to replace NaN's with in the tensor (default = 0).
2951+
posinf
2952+
The value to replace +INF with in the tensor (default max
2953+
in range representable by ``x.dtype``).
2954+
neginf
2955+
The value to replace -INF with in the tensor (default min
2956+
in range representable by ``x.dtype``).
2957+
2958+
Returns
2959+
-------
2960+
out
2961+
The tensor with NaN's, +INF, and -INF replaced with the
2962+
specified and/or default substitutions.
2963+
"""
2964+
# Replace NaN's with nan keyword
2965+
is_nan = isnan(x)
2966+
is_pos_inf = isposinf(x)
2967+
is_neg_inf = isneginf(x)
2968+
2969+
x = switch(is_nan, nan, x)
2970+
2971+
# Get max and min values representable by x.dtype
2972+
maxf = posinf
2973+
minf = neginf
2974+
2975+
# Specify the value to replace +INF and -INF with
2976+
if maxf is None:
2977+
maxf = np.finfo(x.real.dtype).max
2978+
if minf is None:
2979+
minf = np.finfo(x.real.dtype).min
2980+
2981+
# Replace +INF and -INF values
2982+
x = switch(is_pos_inf, maxf, x)
2983+
x = switch(is_neg_inf, minf, x)
2984+
2985+
return x
2986+
2987+
29162988
# NumPy logical aliases
29172989
square = sqr
29182990

@@ -2951,6 +3023,8 @@ def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
29513023
"not_equal",
29523024
"isnan",
29533025
"isinf",
3026+
"isposinf",
3027+
"isneginf",
29543028
"allclose",
29553029
"isclose",
29563030
"and_",
@@ -3069,4 +3143,5 @@ def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
30693143
"logaddexp",
30703144
"logsumexp",
30713145
"hyp2f1",
3146+
"nan_to_num",
30723147
]

tests/tensor/test_math.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@
8080
isinf,
8181
isnan,
8282
isnan_,
83+
isneginf,
84+
isposinf,
8385
log,
8486
log1mexp,
8587
log1p,
@@ -96,6 +98,7 @@
9698
minimum,
9799
mod,
98100
mul,
101+
nan_to_num,
99102
neg,
100103
neq,
101104
outer,
@@ -3689,3 +3692,46 @@ def test_grad_n_undefined(self):
36893692
n = scalar(dtype="int64")
36903693
with pytest.raises(NullTypeGradError):
36913694
grad(polygamma(n, 0.5), wrt=n)
3695+
3696+
3697+
def test_infs():
3698+
x = tensor(shape=(7,))
3699+
3700+
f_pos = function([x], isposinf(x))
3701+
f_neg = function([x], isneginf(x))
3702+
3703+
y = np.array([1, np.inf, 2, np.inf, -np.inf, -np.inf, 4]).astype(x.dtype)
3704+
out_pos = f_pos(y)
3705+
out_neg = f_neg(y)
3706+
3707+
np.testing.assert_allclose(
3708+
out_pos,
3709+
[0, 1, 0, 1, 0, 0, 0],
3710+
)
3711+
np.testing.assert_allclose(
3712+
out_neg,
3713+
[0, 0, 0, 0, 1, 1, 0],
3714+
)
3715+
3716+
3717+
@pytest.mark.parametrize(
3718+
["nan", "posinf", "neginf"],
3719+
[(0, None, None), (0, 0, 0), (0, None, 1000), (3, 1, -1)],
3720+
)
3721+
def test_nan_to_num(nan, posinf, neginf):
3722+
x = tensor(shape=(7,))
3723+
3724+
out = nan_to_num(x, nan, posinf, neginf)
3725+
3726+
f = function([x], out)
3727+
3728+
y = np.array([1, 2, np.nan, np.inf, -np.inf, 3, 4]).astype(x.dtype)
3729+
out = f(y)
3730+
3731+
posinf = np.finfo(x.real.dtype).max if posinf is None else posinf
3732+
neginf = np.finfo(x.real.dtype).min if neginf is None else neginf
3733+
3734+
np.testing.assert_allclose(
3735+
out,
3736+
np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf),
3737+
)

0 commit comments

Comments
 (0)