Skip to content

Commit 326cb2e

Browse files
committed
Fail gracefully in local_pow_to_nested_squaring when static type shape is updated
1 parent 3169197 commit 326cb2e

File tree

2 files changed

+85
-59
lines changed

2 files changed

+85
-59
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 58 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2081,63 +2081,65 @@ def local_pow_to_nested_squaring(fgraph, node):
20812081
Note: This sounds like the kind of thing any half-decent compiler can do by itself?
20822082
"""
20832083

2084-
if node.op == at_pow:
2085-
# the idea here is that we have pow(x, y)
2086-
odtype = node.outputs[0].dtype
2087-
xsym = node.inputs[0]
2088-
ysym = node.inputs[1]
2089-
y = get_constant(ysym)
2090-
2091-
# the next line is needed to fix a strange case that I don't
2092-
# know how to make a separate test.
2093-
# That happen in the `test_log_erfc` test.
2094-
# y is a ndarray with dtype int8 and value 2,4 or 6. This make
2095-
# the abs(y) <= 512 fail!
2096-
# taking the value outside ndarray solve the problem.
2097-
# it could be that in that case, numpy make the comparison
2098-
# into the wrong type(do in int8 that overflow.)
2099-
if isinstance(y, np.ndarray):
2100-
assert y.size == 1
2101-
try:
2102-
y = y[0]
2103-
except IndexError:
2104-
pass
2105-
if (y is not None) and not broadcasted_by(xsym, ysym):
2106-
rval = None
2107-
# 512 is too small for the cpu and too big for some gpu!
2108-
if abs(y) == int(abs(y)) and abs(y) <= 512:
2109-
pow2 = [xsym]
2110-
pow2_scal = [aes.get_scalar_type(xsym.dtype)()]
2111-
y_to_do = abs(y)
2112-
for i in range(int(np.log2(y_to_do))):
2113-
pow2.append(sqr(pow2[i]))
2114-
pow2_scal.append(aes.sqr(pow2_scal[i]))
2115-
rval1 = None
2116-
rval1_scal = None
2117-
while y_to_do > 0:
2118-
log_to_do = int(np.log2(y_to_do))
2119-
if rval1:
2120-
rval1 *= pow2[log_to_do]
2121-
rval1_scal *= pow2_scal[log_to_do]
2122-
else:
2123-
rval1 = pow2[log_to_do]
2124-
rval1_scal = pow2_scal[log_to_do]
2125-
y_to_do -= 2**log_to_do
2126-
2127-
if abs(y) > 2:
2128-
# We fuse all the pow together here to make
2129-
# compilation faster
2130-
rval1 = Elemwise(
2131-
aes.Composite([pow2_scal[0]], [rval1_scal])
2132-
).make_node(xsym)
2133-
if y < 0:
2134-
rval = [reciprocal(rval1)]
2084+
# the idea here is that we have pow(x, y)
2085+
odtype = node.outputs[0].dtype
2086+
xsym = node.inputs[0]
2087+
ysym = node.inputs[1]
2088+
y = get_constant(ysym)
2089+
2090+
# The next lines are needed to fix a strange case that I don't
2091+
# know how to turn into a separate test.
2092+
# That happens in the `test_log_erfc` test.
2093+
# y is an ndarray with dtype int8 and value 2, 4 or 6. This makes
2094+
# the abs(y) <= 512 check fail!
2095+
# Taking the value out of the ndarray solves the problem.
2096+
# It could be that in that case, numpy does the comparison
2097+
# in the wrong type (doing it in int8, which overflows).
2098+
if isinstance(y, np.ndarray):
2099+
assert y.size == 1
2100+
try:
2101+
y = y[0]
2102+
except IndexError:
2103+
pass
2104+
if (y is not None) and not broadcasted_by(xsym, ysym):
2105+
rval = None
2106+
# 512 is too small for the cpu and too big for some gpu!
2107+
if abs(y) == int(abs(y)) and abs(y) <= 512:
2108+
pow2 = [xsym]
2109+
pow2_scal = [aes.get_scalar_type(xsym.dtype)()]
2110+
y_to_do = abs(y)
2111+
for i in range(int(np.log2(y_to_do))):
2112+
pow2.append(sqr(pow2[i]))
2113+
pow2_scal.append(aes.sqr(pow2_scal[i]))
2114+
rval1 = None
2115+
rval1_scal = None
2116+
while y_to_do > 0:
2117+
log_to_do = int(np.log2(y_to_do))
2118+
if rval1:
2119+
rval1 *= pow2[log_to_do]
2120+
rval1_scal *= pow2_scal[log_to_do]
21352121
else:
2136-
rval = [rval1]
2137-
if rval:
2138-
rval[0] = cast(rval[0], odtype)
2139-
assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
2140-
return rval
2122+
rval1 = pow2[log_to_do]
2123+
rval1_scal = pow2_scal[log_to_do]
2124+
y_to_do -= 2**log_to_do
2125+
2126+
if abs(y) > 2:
2127+
# We fuse all the pow together here to make
2128+
# compilation faster
2129+
rval1 = Elemwise(aes.Composite([pow2_scal[0]], [rval1_scal])).make_node(
2130+
xsym
2131+
)
2132+
if y < 0:
2133+
rval = [reciprocal(rval1)]
2134+
else:
2135+
rval = [rval1]
2136+
if rval:
2137+
rval[0] = cast(rval[0], odtype)
2138+
# TODO: We can add a specify_broadcastable and/or unbroadcast to make the
2139+
# output types compatible. Or work on #408 and let TensorType.filter_variable do it.
2140+
if rval[0].type.broadcastable != node.outputs[0].type.broadcastable:
2141+
return None
2142+
return rval
21412143

21422144

21432145
@register_specialize

tests/tensor/rewriting/test_math.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@
2929
from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
3030
from pytensor.misc.safe_asarray import _asarray
3131
from pytensor.printing import debugprint
32+
from pytensor.scalar import Pow
3233
from pytensor.tensor import inplace
33-
from pytensor.tensor.basic import Alloc, join, second, switch
34+
from pytensor.tensor.basic import Alloc, constant, join, second, switch
3435
from pytensor.tensor.blas import Dot22, Gemv
3536
from pytensor.tensor.blas_c import CGemv
3637
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -69,7 +70,7 @@
6970
from pytensor.tensor.math import maximum
7071
from pytensor.tensor.math import min as at_min
7172
from pytensor.tensor.math import minimum, mul, neg, neq
72-
from pytensor.tensor.math import pow as at_pow
73+
from pytensor.tensor.math import pow as pt_pow
7374
from pytensor.tensor.math import (
7475
prod,
7576
rad2deg,
@@ -1746,6 +1747,29 @@ def test_local_pow_to_nested_squaring():
17461747
utt.assert_allclose(f(val_no0), val_no0 ** (-16))
17471748

17481749

1750+
def test_local_pow_to_nested_squaring_fails_gracefully():
1751+
# Reported in #456
1752+
1753+
x = vector("x", shape=(1,))
1754+
# Create an Apply that does not have a precise output shape
1755+
node = Apply(
1756+
op=pt_pow,
1757+
inputs=[x, constant([2.0])],
1758+
outputs=[tensor(shape=(None,))],
1759+
)
1760+
y = node.default_output()
1761+
1762+
fn = function([x], y)
1763+
1764+
# Check rewrite is not applied (this could change in the future)
1765+
assert any(
1766+
(isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Pow))
1767+
for node in fn.maker.fgraph.apply_nodes
1768+
)
1769+
1770+
np.testing.assert_allclose(fn([2.0]), np.array([4.0]))
1771+
1772+
17491773
class TestFuncInverse:
17501774
def setup_method(self):
17511775
mode = get_default_mode()
@@ -2449,7 +2473,7 @@ def test_elemwise(self):
24492473
le,
24502474
eq,
24512475
neq,
2452-
at_pow,
2476+
pt_pow,
24532477
):
24542478
g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
24552479
assert debugprint(g, file="str").count("Switch") == 1

0 commit comments

Comments
 (0)