Skip to content

Commit 8ac8342

Browse files
committed
Fix overly strict check in local_pow_specialize rewrite
1 parent 4459199 commit 8ac8342

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2072,11 +2072,10 @@ def local_pow_specialize(fgraph, node):
20722072
if np.all(y == -2):
20732073
rval = [reciprocal(sqr(xsym))]
20742074
if rval:
2075+
if not rval[0].type.broadcastable == node.outputs[0].type.broadcastable:
2076+
return None
20752077
rval[0] = cast(rval[0], odtype)
2076-
assert rval[0].type.is_super(node.outputs[0].type), (
2077-
rval[0].type,
2078-
node.outputs[0].type,
2079-
)
2078+
assert rval[0].type.dtype == node.outputs[0].type.dtype
20802079
return rval
20812080
else:
20822081
return False

0 commit comments

Comments
 (0)