Skip to content

Fix miscelaneous bugs #368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def signature(self):
return (self.type, self.data)

def __str__(self):
data_str = str(self.data)
data_str = str(self.data).replace("\n", "")
if len(data_str) > 20:
data_str = data_str[:10].strip() + " ... " + data_str[-10:].strip()

Expand Down
1 change: 0 additions & 1 deletion pytensor/graph/rewriting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def rewrite_graph(

return_fgraph = False
if isinstance(graph, FunctionGraph):
outputs: Sequence[Variable] = graph.outputs
fgraph = graph
return_fgraph = True
else:
Expand Down
7 changes: 5 additions & 2 deletions pytensor/tensor/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ def __init__(self, input_broadcastable, new_order):
super().__init__([self.c_func_file], self.c_func_name)

self.input_broadcastable = tuple(input_broadcastable)
if not all(isinstance(bs, (bool, np.bool_)) for bs in self.input_broadcastable):
raise ValueError(
f"input_broadcastable must be boolean, {self.input_broadcastable}"
)
self.new_order = tuple(new_order)

self.inplace = True
Expand Down Expand Up @@ -411,10 +415,9 @@ def get_output_info(self, dim_shuffle, *inputs):
if not difference:
args.append(input)
else:
# TODO: use LComplete instead
args.append(
dim_shuffle(
tuple(1 if s == 1 else None for s in input.type.shape),
input.type.broadcastable,
["x"] * difference + list(range(length)),
)(input)
)
Expand Down
5 changes: 4 additions & 1 deletion pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2072,7 +2072,10 @@ def local_pow_specialize(fgraph, node):
rval = [reciprocal(sqr(xsym))]
if rval:
rval[0] = cast(rval[0], odtype)
assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
assert rval[0].type.is_super(node.outputs[0].type), (
rval[0].type,
node.outputs[0].type,
)
return rval
else:
return False
Expand Down
14 changes: 13 additions & 1 deletion tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
perform_sigm_times_exp,
simplify_mul,
)
from pytensor.tensor.shape import Reshape, Shape_i
from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape
from pytensor.tensor.type import (
TensorType,
cmatrix,
Expand Down Expand Up @@ -1671,6 +1671,18 @@ def test_local_pow_specialize():
assert isinstance(nodes[1].scalar_op, aes.basic.Reciprocal)
utt.assert_allclose(f(val_no0), val_no0 ** (-0.5))

twos = np.full(shape=(10,), fill_value=2.0).astype(config.floatX)
f = function([v], v**twos, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
# Depending on the mode the SpecifyShape is lifted or not
if topo[0].op == sqr:
assert isinstance(topo[1].op, SpecifyShape)
else:
assert isinstance(topo[0].op, SpecifyShape)
assert topo[1].op == sqr
utt.assert_allclose(f(val), val**twos)


def test_local_pow_specialize_device_more_aggressive_on_cpu():
mode = config.mode
Expand Down
6 changes: 6 additions & 0 deletions tests/tensor/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ def test_static_shape(self):
y = x.dimshuffle([0, 1, "x"])
assert y.type.shape == (1, 2, 1)

def test_valid_input_broadcastable(self):
assert DimShuffle([True, False], (1, 0)).input_broadcastable == (True, False)

with pytest.raises(ValueError, match="input_broadcastable must be boolean"):
DimShuffle([None, None], (1, 0))


class TestBroadcast:
# this is to allow other types to reuse this class to test their ops
Expand Down