Fail gracefully in local_pow_to_nested_squaring when static type shape is updated #461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 1 commit, Oct 1, 2023
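For context on the rewrite this PR touches: `local_pow_to_nested_squaring` replaces `pow(x, y)` with a chain of squarings when `y` is a small integer constant, following the binary decomposition of `y`. Below is a minimal sketch of that idea in plain NumPy; it is illustrative only, not the pytensor implementation, and the helper name is made up.

import numpy as np


def pow_by_nested_squaring(x, y):
    """Hypothetical helper: compute x**y for a nonzero integer y via repeated squaring."""
    assert y == int(y) and y != 0
    y_to_do = abs(int(y))
    # Precompute x, x**2, x**4, ... up to the largest power of two needed
    pow2 = [x]
    for i in range(int(np.log2(y_to_do))):
        pow2.append(np.square(pow2[i]))
    # Multiply together the precomputed squares selected by the binary digits of y
    result = None
    while y_to_do > 0:
        log_to_do = int(np.log2(y_to_do))
        result = pow2[log_to_do] if result is None else result * pow2[log_to_do]
        y_to_do -= 2**log_to_do
    return 1.0 / result if y < 0 else result


x = np.array([1.5, 2.0, 3.0])
np.testing.assert_allclose(pow_by_nested_squaring(x, 6), x**6)  # computed as (x**4) * (x**2)
np.testing.assert_allclose(pow_by_nested_squaring(x, -16), x**-16.0)

The rewrite in the diff below builds the same chain symbolically with `sqr`, fuses it into a scalar `Composite` when `abs(y) > 2`, and wraps it in `reciprocal` for negative exponents.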
pytensor/tensor/rewriting/math.py: 114 changes (58 additions & 56 deletions)
@@ -2081,63 +2081,65 @@ def local_pow_to_nested_squaring(fgraph, node):
     Note: This sounds like the kind of thing any half-decent compiler can do by itself?
     """
 
-    if node.op == at_pow:
-        # the idea here is that we have pow(x, y)
-        odtype = node.outputs[0].dtype
-        xsym = node.inputs[0]
-        ysym = node.inputs[1]
-        y = get_constant(ysym)
-
-        # the next line is needed to fix a strange case that I don't
-        # know how to make a separate test.
-        # That happen in the `test_log_erfc` test.
-        # y is a ndarray with dtype int8 and value 2,4 or 6. This make
-        # the abs(y) <= 512 fail!
-        # taking the value outside ndarray solve the problem.
-        # it could be that in that case, numpy make the comparison
-        # into the wrong type(do in int8 that overflow.)
-        if isinstance(y, np.ndarray):
-            assert y.size == 1
-            try:
-                y = y[0]
-            except IndexError:
-                pass
-        if (y is not None) and not broadcasted_by(xsym, ysym):
-            rval = None
-            # 512 is too small for the cpu and too big for some gpu!
-            if abs(y) == int(abs(y)) and abs(y) <= 512:
-                pow2 = [xsym]
-                pow2_scal = [aes.get_scalar_type(xsym.dtype)()]
-                y_to_do = abs(y)
-                for i in range(int(np.log2(y_to_do))):
-                    pow2.append(sqr(pow2[i]))
-                    pow2_scal.append(aes.sqr(pow2_scal[i]))
-                rval1 = None
-                rval1_scal = None
-                while y_to_do > 0:
-                    log_to_do = int(np.log2(y_to_do))
-                    if rval1:
-                        rval1 *= pow2[log_to_do]
-                        rval1_scal *= pow2_scal[log_to_do]
-                    else:
-                        rval1 = pow2[log_to_do]
-                        rval1_scal = pow2_scal[log_to_do]
-                    y_to_do -= 2**log_to_do
-
-                if abs(y) > 2:
-                    # We fuse all the pow together here to make
-                    # compilation faster
-                    rval1 = Elemwise(
-                        aes.Composite([pow2_scal[0]], [rval1_scal])
-                    ).make_node(xsym)
-                if y < 0:
-                    rval = [reciprocal(rval1)]
-                else:
-                    rval = [rval1]
-            if rval:
-                rval[0] = cast(rval[0], odtype)
-                assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
-                return rval
+    # the idea here is that we have pow(x, y)
+    odtype = node.outputs[0].dtype
+    xsym = node.inputs[0]
+    ysym = node.inputs[1]
+    y = get_constant(ysym)
+
+    # the next line is needed to fix a strange case that I don't
+    # know how to make a separate test.
+    # That happen in the `test_log_erfc` test.
+    # y is a ndarray with dtype int8 and value 2,4 or 6. This make
+    # the abs(y) <= 512 fail!
+    # taking the value outside ndarray solve the problem.
+    # it could be that in that case, numpy make the comparison
+    # into the wrong type(do in int8 that overflow.)
+    if isinstance(y, np.ndarray):
+        assert y.size == 1
+        try:
+            y = y[0]
+        except IndexError:
+            pass
+    if (y is not None) and not broadcasted_by(xsym, ysym):
+        rval = None
+        # 512 is too small for the cpu and too big for some gpu!
+        if abs(y) == int(abs(y)) and abs(y) <= 512:
+            pow2 = [xsym]
+            pow2_scal = [aes.get_scalar_type(xsym.dtype)()]
+            y_to_do = abs(y)
+            for i in range(int(np.log2(y_to_do))):
+                pow2.append(sqr(pow2[i]))
+                pow2_scal.append(aes.sqr(pow2_scal[i]))
+            rval1 = None
+            rval1_scal = None
+            while y_to_do > 0:
+                log_to_do = int(np.log2(y_to_do))
+                if rval1:
+                    rval1 *= pow2[log_to_do]
+                    rval1_scal *= pow2_scal[log_to_do]
+                else:
+                    rval1 = pow2[log_to_do]
+                    rval1_scal = pow2_scal[log_to_do]
+                y_to_do -= 2**log_to_do
+
+            if abs(y) > 2:
+                # We fuse all the pow together here to make
+                # compilation faster
+                rval1 = Elemwise(aes.Composite([pow2_scal[0]], [rval1_scal])).make_node(
+                    xsym
+                )
+            if y < 0:
+                rval = [reciprocal(rval1)]
+            else:
+                rval = [rval1]
+        if rval:
+            rval[0] = cast(rval[0], odtype)
+            # TODO: We can add a specify_broadcastable and/or unbroadcast to make the
+            # output types compatible. Or work on #408 and let TensorType.filter_variable do it.
+            if rval[0].type.broadcastable != node.outputs[0].type.broadcastable:
+                return None
+            return rval
 
 
 @register_specialize
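The behavioral change in the hunk above is confined to the final `if rval:` block: where the old code asserted that the rewritten output's type exactly matched the node's output type, the new code compares only the broadcastable patterns and returns None on a mismatch, leaving the original `pow` node in the graph. A rough sketch of the comparison the guard performs, assuming the usual pytensor `TensorType` behavior where the static shape determines `broadcastable`:

from pytensor.tensor.type import TensorType

# Type the rewrite would produce when the input has a fully known shape (1,)
rewritten = TensorType("float64", shape=(1,))     # broadcastable == (True,)
# Stale output type on the node, with an unknown static shape
original = TensorType("float64", shape=(None,))   # broadcastable == (False,)

# The patterns differ, so local_pow_to_nested_squaring now bails out (returns None)
# instead of tripping the old assert.
assert rewritten.broadcastable != original.broadcastable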
tests/tensor/rewriting/test_math.py: 30 changes (27 additions & 3 deletions)
@@ -29,8 +29,9 @@
 from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.printing import debugprint
+from pytensor.scalar import Pow
 from pytensor.tensor import inplace
-from pytensor.tensor.basic import Alloc, join, second, switch
+from pytensor.tensor.basic import Alloc, constant, join, second, switch
 from pytensor.tensor.blas import Dot22, Gemv
 from pytensor.tensor.blas_c import CGemv
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -69,7 +70,7 @@
 from pytensor.tensor.math import maximum
 from pytensor.tensor.math import min as at_min
 from pytensor.tensor.math import minimum, mul, neg, neq
-from pytensor.tensor.math import pow as at_pow
+from pytensor.tensor.math import pow as pt_pow
 from pytensor.tensor.math import (
     prod,
     rad2deg,
@@ -1746,6 +1747,29 @@ def test_local_pow_to_nested_squaring():
     utt.assert_allclose(f(val_no0), val_no0 ** (-16))
 
 
+def test_local_pow_to_nested_squaring_fails_gracefully():
+    # Reported in #456
+
+    x = vector("x", shape=(1,))
+    # Create an Apply that does not have precise output shape
+    node = Apply(
+        op=pt_pow,
+        inputs=[x, constant([2.0])],
+        outputs=[tensor(shape=(None,))],
+    )
+    y = node.default_output()
+
+    fn = function([x], y)
+
+    # Check rewrite is not applied (this could change in the future)
+    assert any(
+        (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Pow))
+        for node in fn.maker.fgraph.apply_nodes
+    )
+
+    np.testing.assert_allclose(fn([2.0]), np.array([4.0]))
+
+
 class TestFuncInverse:
     def setup_method(self):
         mode = get_default_mode()
@@ -2449,7 +2473,7 @@ def test_elemwise(self):
             le,
             eq,
             neq,
-            at_pow,
+            pt_pow,
         ):
             g = rewrite(FunctionGraph(mats, [op(s1, s2)]))
             assert debugprint(g, file="str").count("Switch") == 1
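As a counterpart to the new regression test, here is a small sketch of the rewrite firing in the ordinary case, where the static shapes line up. It is not part of the PR and relies only on public entry points (`pytensor.function`, `pytensor.dprint`):

import numpy as np
import pytensor
from pytensor.tensor import vector

x = vector("x")
f = pytensor.function([x], x**8)
# With default rewrites the compiled graph is built from squarings (typically a fused
# Composite of sqr ops) rather than a Pow node.
pytensor.dprint(f)
np.testing.assert_allclose(f([2.0]), [256.0])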