Description
import numpy as np
import pytensor.tensor as pt

pt.get_scalar_constant_value(pt.constant(np.zeros(5)))   # returns 0.0; wrong, the input is not a scalar
pt.get_scalar_constant_value(pt.constant(np.arange(5)))  # raises NotScalarConstantError; correct
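For comparison, here is a minimal sketch of the stricter semantics this report argues for; strict_scalar_constant_value is a hypothetical helper written for illustration, not part of the PyTensor API:

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.variable import TensorConstant

def strict_scalar_constant_value(v):
    # Only genuinely 0-d constants yield a value; uniform but
    # non-scalar data such as np.zeros(5) raises instead.
    if isinstance(v, TensorConstant) and v.ndim == 0:
        return v.data
    raise NotScalarConstantError()

strict_scalar_constant_value(pt.constant(2.0))          # array(2.0)
strict_scalar_constant_value(pt.constant(np.zeros(5)))  # raises NotScalarConstantError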
This inconsistency makes the following rewrite bail out unnecessarily in some cases:
pytensor/pytensor/tensor/rewriting/elemwise.py, lines 468 to 536 at d7d20be:
def local_upcast_elemwise_constant_inputs(fgraph, node):
    """This explicitly upcasts constant inputs to elemwise Ops, when
    those Ops do implicit upcasting anyway.

    Rationale: it helps merge things like (1-x) and (1.0 - x).
    """
    if len(node.outputs) > 1:
        return
    try:
        shape_i = fgraph.shape_feature.shape_i
    except AttributeError:
        shape_i = None
    if isinstance(node.op, Elemwise):
        scalar_op = node.op.scalar_op
        # print "aa", scalar_op.output_types_preference
        if getattr(scalar_op, "output_types_preference", None) in (
            aes.upgrade_to_float,
            aes.upcast_out,
        ):
            # this is the kind of op that we can screw with the input
            # dtypes by upcasting explicitly
            output_dtype = node.outputs[0].type.dtype
            new_inputs = []
            for i in node.inputs:
                if i.type.dtype == output_dtype:
                    new_inputs.append(i)
                else:
                    try:
                        # works only for scalars
                        cval_i = get_scalar_constant_value(
                            i, only_process_constants=True
                        )
                        if all(i.broadcastable):
                            new_inputs.append(
                                shape_padleft(cast(cval_i, output_dtype), i.ndim)
                            )
                        else:
                            if shape_i is None:
                                return
                            new_inputs.append(
                                alloc(
                                    cast(cval_i, output_dtype),
                                    *[shape_i(d)(i) for d in range(i.ndim)],
                                )
                            )
                            # print >> sys.stderr, "AAA",
                            # *[Shape_i(d)(i) for d in range(i.ndim)]
                    except NotScalarConstantError:
                        # for the case of a non-scalar
                        if isinstance(i, TensorConstant):
                            new_inputs.append(cast(i, output_dtype))
                        else:
                            new_inputs.append(i)
            if new_inputs != node.inputs:
                rval = [node.op(*new_inputs)]
                if not node.outputs[0].type.is_super(rval[0].type):
                    # This can happen for example when floatX=float32
                    # and we do the true division between and int64
                    # and a constant that will get typed as int8.
                    # As this is just to allow merging more case, if
                    # the upcast don't work, we can just skip it.
                    return
                # Copy over output stacktrace from before upcasting
                copy_stack_trace(node.outputs[0], rval)
                return rval
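A sketch of the failure path, assuming a FunctionGraph with no ShapeFeature attached (so fgraph.shape_feature is missing and shape_i is None inside the rewrite) and calling the rewriter directly through the NodeRewriter.transform interface:

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs

x = pt.vector("x", dtype="float32")
c = pt.constant(np.zeros(5, dtype="int8"))  # uniform, non-scalar constant
out = x + c  # Elemwise Add upcasts the output to float32

# clone=False so out.owner is the node inside the graph; no ShapeFeature
# is attached, hence shape_i is None inside the rewrite.
fgraph = FunctionGraph([x], [out], clone=False)

# get_scalar_constant_value wrongly succeeds on the uniform constant, so the
# rewrite takes the alloc branch, finds shape_i is None, and returns without
# rewriting, even though the except NotScalarConstantError branch would have
# handled the TensorConstant by simply casting it to float32.
print(local_upcast_elemwise_constant_inputs.transform(fgraph, out.owner))  # None: the rewrite bailed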