
get_scalar_constant_value does not raise for homogeneous non-scalar constants #226

Closed

@ricardoV94

Description

import numpy as np
import pytensor.tensor as pt

pt.get_scalar_constant_value(pt.constant(np.zeros(5)))   # returns 0.0; wrong, should raise
pt.get_scalar_constant_value(pt.constant(np.arange(5)))  # raises NotScalarConstantError; correct
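
For reference, here is a minimal sketch of the stricter check this issue asks for (the helper name is hypothetical, not pytensor API): a constant should only count as scalar when it is genuinely 0-d, not merely homogeneous.

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Constant
from pytensor.tensor.exceptions import NotScalarConstantError

def strict_scalar_constant_value(v):
    # Hypothetical helper, not pytensor API: accept only genuinely 0-d
    # constants, rather than any constant whose entries happen to share
    # one value.
    if isinstance(v, Constant) and getattr(v, "ndim", None) == 0:
        return v.data
    raise NotScalarConstantError()

strict_scalar_constant_value(pt.constant(0.0))          # returns array(0.)
strict_scalar_constant_value(pt.constant(np.zeros(5)))  # raises NotScalarConstantError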

This makes the rewrite below fail unnecessarily in some cases: get_scalar_constant_value wrongly succeeds on a homogeneous non-scalar constant, so the rewrite takes the scalar path, and when the graph has no shape feature (shape_i is None) it bails out entirely, even though the except NotScalarConstantError branch would have handled the constant by simply casting it:

def local_upcast_elemwise_constant_inputs(fgraph, node):
    """Explicitly upcast constant inputs to Elemwise ops whose scalar op
    does implicit upcasting anyway.

    Rationale: it helps merge things like (1 - x) and (1.0 - x).
    """
    if len(node.outputs) > 1:
        return
    try:
        shape_i = fgraph.shape_feature.shape_i
    except AttributeError:
        shape_i = None
    if isinstance(node.op, Elemwise):
        scalar_op = node.op.scalar_op
        if getattr(scalar_op, "output_types_preference", None) in (
            aes.upgrade_to_float,
            aes.upcast_out,
        ):
            # This is the kind of op whose input dtypes we can safely
            # change by upcasting explicitly.
            output_dtype = node.outputs[0].type.dtype
            new_inputs = []
            for i in node.inputs:
                if i.type.dtype == output_dtype:
                    new_inputs.append(i)
                else:
                    try:
                        # Works only for scalar constants
                        cval_i = get_scalar_constant_value(
                            i, only_process_constants=True
                        )
                        if all(i.broadcastable):
                            new_inputs.append(
                                shape_padleft(cast(cval_i, output_dtype), i.ndim)
                            )
                        else:
                            if shape_i is None:
                                return
                            new_inputs.append(
                                alloc(
                                    cast(cval_i, output_dtype),
                                    *[shape_i(d)(i) for d in range(i.ndim)],
                                )
                            )
                    except NotScalarConstantError:
                        # For the case of a non-scalar constant
                        if isinstance(i, TensorConstant):
                            new_inputs.append(cast(i, output_dtype))
                        else:
                            new_inputs.append(i)
            if new_inputs != node.inputs:
                rval = [node.op(*new_inputs)]
                if not node.outputs[0].type.is_super(rval[0].type):
                    # This can happen for example when floatX=float32 and
                    # we do the true division between an int64 and a
                    # constant that will get typed as int8. As this rewrite
                    # only exists to allow merging more cases, if the
                    # upcast doesn't work, we can just skip it.
                    return
                # Copy over output stacktrace from before upcasting
                copy_stack_trace(node.outputs[0], rval)
                return rval
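
To make the expected contract concrete, here is a small check one could run against a fix (a sketch using only the public calls shown above): both non-scalar constants should raise, homogeneous or not.

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.exceptions import NotScalarConstantError

# Expected behavior after the fix: any non-scalar constant raises,
# whether its entries are homogeneous (zeros) or not (arange).
for data in (np.zeros(5), np.arange(5)):
    try:
        pt.get_scalar_constant_value(pt.constant(data))
        print(f"{data!r}: returned a value (wrong)")
    except NotScalarConstantError:
        print(f"{data!r}: raised NotScalarConstantError (correct)")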
