diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 90d054f6ba..0f96556475 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -773,7 +773,7 @@ def signature(self): return (self.type, self.data) def __str__(self): - data_str = str(self.data) + data_str = str(self.data).replace("\n", "") if len(data_str) > 20: data_str = data_str[:10].strip() + " ... " + data_str[-10:].strip() diff --git a/pytensor/graph/rewriting/utils.py b/pytensor/graph/rewriting/utils.py index 8bf8de87bb..63cc436396 100644 --- a/pytensor/graph/rewriting/utils.py +++ b/pytensor/graph/rewriting/utils.py @@ -45,7 +45,6 @@ def rewrite_graph( return_fgraph = False if isinstance(graph, FunctionGraph): - outputs: Sequence[Variable] = graph.outputs fgraph = graph return_fgraph = True else: diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 6d19579030..0687d30f10 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -130,6 +130,10 @@ def __init__(self, input_broadcastable, new_order): super().__init__([self.c_func_file], self.c_func_name) self.input_broadcastable = tuple(input_broadcastable) + if not all(isinstance(bs, (bool, np.bool_)) for bs in self.input_broadcastable): + raise ValueError( + f"input_broadcastable must be boolean, {self.input_broadcastable}" + ) self.new_order = tuple(new_order) self.inplace = True @@ -411,10 +415,9 @@ def get_output_info(self, dim_shuffle, *inputs): if not difference: args.append(input) else: - # TODO: use LComplete instead args.append( dim_shuffle( - tuple(1 if s == 1 else None for s in input.type.shape), + input.type.broadcastable, ["x"] * difference + list(range(length)), )(input) ) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index b0d124f1c8..46b98a1575 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -2072,7 +2072,10 @@ def local_pow_specialize(fgraph, node): rval = [reciprocal(sqr(xsym))] if rval: rval[0] = cast(rval[0], odtype) - assert rval[0].type == node.outputs[0].type, (rval, node.outputs) + assert rval[0].type.is_super(node.outputs[0].type), ( + rval[0].type, + node.outputs[0].type, + ) return rval else: return False diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index f69879a51d..4e9d143a8e 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -96,7 +96,7 @@ perform_sigm_times_exp, simplify_mul, ) -from pytensor.tensor.shape import Reshape, Shape_i +from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape from pytensor.tensor.type import ( TensorType, cmatrix, @@ -1671,6 +1671,18 @@ def test_local_pow_specialize(): assert isinstance(nodes[1].scalar_op, aes.basic.Reciprocal) utt.assert_allclose(f(val_no0), val_no0 ** (-0.5)) + twos = np.full(shape=(10,), fill_value=2.0).astype(config.floatX) + f = function([v], v**twos, mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 2 + # Depending on the mode the SpecifyShape is lifted or not + if topo[0].op == sqr: + assert isinstance(topo[1].op, SpecifyShape) + else: + assert isinstance(topo[0].op, SpecifyShape) + assert topo[1].op == sqr + utt.assert_allclose(f(val), val**twos) + def test_local_pow_specialize_device_more_aggressive_on_cpu(): mode = config.mode diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 40e7db879c..fa820f062a 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -188,6 +188,12 @@ def test_static_shape(self): y = x.dimshuffle([0, 1, "x"]) assert y.type.shape == (1, 2, 1) + def test_valid_input_broadcastable(self): + assert DimShuffle([True, False], (1, 0)).input_broadcastable == (True, False) + + with pytest.raises(ValueError, match="input_broadcastable must be boolean"): + DimShuffle([None, None], (1, 0)) + class TestBroadcast: # this is to allow other types to reuse this class to test their ops