We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
local_pow_specialize
1 parent 4459199 commit 8ac8342Copy full SHA for 8ac8342
pytensor/tensor/rewriting/math.py
@@ -2072,11 +2072,10 @@ def local_pow_specialize(fgraph, node):
2072
if np.all(y == -2):
2073
rval = [reciprocal(sqr(xsym))]
2074
if rval:
2075
+ if not rval[0].type.broadcastable == node.outputs[0].type.broadcastable:
2076
+ return None
2077
rval[0] = cast(rval[0], odtype)
- assert rval[0].type.is_super(node.outputs[0].type), (
- rval[0].type,
2078
- node.outputs[0].type,
2079
- )
+ assert rval[0].type.dtype == node.outputs[0].type.dtype
2080
return rval
2081
else:
2082
return False
0 commit comments