Skip to content

Commit 64e60ed

Browse files
Replaced use of isposinf and isneginf op with existing ops
1 parent 829309b commit 64e60ed

File tree

2 files changed

+2
-92
lines changed

2 files changed

+2
-92
lines changed

pytensor/scalar/basic.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,56 +1533,6 @@ def c_code_cache_version(self):
15331533
isinf = IsInf()
15341534

15351535

1536-
class IsPosInf(FixedLogicalComparison):
1537-
nfunc_spec = ("isposinf", 1, 1)
1538-
1539-
def impl(self, x):
1540-
return np.isposinf(x)
1541-
1542-
def c_code(self, node, name, inputs, outputs, sub):
1543-
(x,) = inputs
1544-
(z,) = outputs
1545-
if node.inputs[0].type in complex_types:
1546-
raise NotImplementedError()
1547-
# Discrete type can never be posinf
1548-
if node.inputs[0].type in discrete_types:
1549-
return f"{z} = false;"
1550-
1551-
return f"{z} = isinf({x}) && !signbit({x});"
1552-
1553-
def c_code_cache_version(self):
1554-
scalarop_version = super().c_code_cache_version()
1555-
return (*scalarop_version, 4)
1556-
1557-
1558-
isposinf = IsPosInf()
1559-
1560-
1561-
class IsNegInf(FixedLogicalComparison):
1562-
nfunc_spec = ("isneginf", 1, 1)
1563-
1564-
def impl(self, x):
1565-
return np.isneginf(x)
1566-
1567-
def c_code(self, node, name, inputs, outputs, sub):
1568-
(x,) = inputs
1569-
(z,) = outputs
1570-
if node.inputs[0].type in complex_types:
1571-
raise NotImplementedError()
1572-
# Discrete type can never be neginf
1573-
if node.inputs[0].type in discrete_types:
1574-
return f"{z} = false;"
1575-
1576-
return f"{z} = isinf({x}) && signbit({x});"
1577-
1578-
def c_code_cache_version(self):
1579-
scalarop_version = super().c_code_cache_version()
1580-
return (*scalarop_version, 4)
1581-
1582-
1583-
isneginf = IsNegInf()
1584-
1585-
15861536
class InRange(LogicalComparison):
15871537
nin = 3
15881538

pytensor/tensor/math.py

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -881,46 +881,6 @@ def isinf(a):
881881
return isinf_(a)
882882

883883

884-
@scalar_elemwise
885-
def isposinf(a):
886-
"""isposinf(a)"""
887-
888-
889-
# Rename isposnan to isposnan_ to allow to bypass it when not needed.
890-
# glibc 2.23 don't allow isposnan on int, so we remove it from the graph.
891-
isposinf_ = isposinf
892-
893-
894-
def isposinf(a):
895-
"""isposinf(a)"""
896-
a = as_tensor_variable(a)
897-
if a.dtype in discrete_dtypes:
898-
return alloc(
899-
np.asarray(False, dtype="bool"), *[a.shape[i] for i in range(a.ndim)]
900-
)
901-
return isposinf_(a)
902-
903-
904-
@scalar_elemwise
905-
def isneginf(a):
906-
"""isneginf(a)"""
907-
908-
909-
# Rename isnegnan to isnegnan_ to allow to bypass it when not needed.
910-
# glibc 2.23 don't allow isnegnan on int, so we remove it from the graph.
911-
isneginf_ = isneginf
912-
913-
914-
def isneginf(a):
915-
"""isneginf(a)"""
916-
a = as_tensor_variable(a)
917-
if a.dtype in discrete_dtypes:
918-
return alloc(
919-
np.asarray(False, dtype="bool"), *[a.shape[i] for i in range(a.ndim)]
920-
)
921-
return isneginf_(a)
922-
923-
924884
def allclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
925885
"""
926886
Implement Numpy's ``allclose`` on tensors.
@@ -3117,8 +3077,8 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
31173077
"""
31183078
# Replace NaN's with nan keyword
31193079
is_nan = isnan(x)
3120-
is_pos_inf = isposinf(x)
3121-
is_neg_inf = isneginf(x)
3080+
is_pos_inf = eq(x, np.inf)
3081+
is_neg_inf = eq(x, -np.inf)
31223082

31233083
if not any(is_nan) and not any(is_pos_inf) and not any(is_neg_inf):
31243084
return

0 commit comments

Comments
 (0)