Skip to content

Commit b8e26cd

Browse files
committed
Make local_pow_to_nested_squaring more permissive
1 parent e48ff56 commit b8e26cd

File tree

2 files changed

+2
-12
lines changed

2 files changed

+2
-12
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,10 +2120,6 @@ def local_pow_to_nested_squaring(fgraph, node):
21202120
rval = [rval1]
21212121
if rval:
21222122
rval[0] = cast(rval[0], odtype)
2123-
# TODO: We can add a specify_broadcastable and/or unbroadcast to make the
2124-
# output types compatible. Or work on #408 and let TensorType.filter_variable do it.
2125-
if rval[0].type.broadcastable != node.outputs[0].type.broadcastable:
2126-
return None
21272123
return rval
21282124

21292125

tests/tensor/rewriting/test_math.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
3030
from pytensor.misc.safe_asarray import _asarray
3131
from pytensor.printing import debugprint
32-
from pytensor.scalar import PolyGamma, Pow, Psi, TriGamma
32+
from pytensor.scalar import PolyGamma, Psi, TriGamma
3333
from pytensor.tensor import inplace
3434
from pytensor.tensor.basic import Alloc, constant, join, second, switch
3535
from pytensor.tensor.blas import Dot22, Gemv
@@ -1757,7 +1757,7 @@ def test_local_pow_to_nested_squaring():
17571757
utt.assert_allclose(f(val_no0), val_no0 ** (-16))
17581758

17591759

1760-
def test_local_pow_to_nested_squaring_fails_gracefully():
1760+
def test_local_pow_to_nested_squaring_works_with_static_type():
17611761
# Reported in #456
17621762

17631763
x = vector("x", shape=(1,))
@@ -1771,12 +1771,6 @@ def test_local_pow_to_nested_squaring_fails_gracefully():
17711771

17721772
fn = function([x], y)
17731773

1774-
# Check rewrite is not applied (this could change in the future)
1775-
assert any(
1776-
(isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Pow))
1777-
for node in fn.maker.fgraph.apply_nodes
1778-
)
1779-
17801774
np.testing.assert_allclose(fn([2.0]), np.array([4.0]))
17811775

17821776

0 commit comments

Comments
 (0)