Skip to content

Commit fc21336

Browse files
committed
Allow fill_sink rewrite to accomodate changes in broadcastability
1 parent a6255d6 commit fc21336

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

pytensor/tensor/rewriting/basic.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -351,10 +351,7 @@ def local_fill_sink(fgraph, node):
351351
# Check if we need to propagate the fill to the new outputs
352352
# It's enough to check the first output, as Elemwise outputs must all have the same shapes
353353
# Note: There are orderings that may require fewer fills.
354-
old_bcast_pattern = node.outputs[0].type.broadcastable
355-
models_iter = iter(models)
356-
while old_bcast_pattern != outputs[0].type.broadcastable:
357-
model = next(models_iter)
354+
for model in models:
358355
# Only apply this model if it would actually do anything
359356
if broadcasted_by(outputs[0], model):
360357
outputs = [fill(model, output) for output in outputs]

tests/tensor/rewriting/test_basic.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytensor
77
import pytensor.scalar as ps
88
import pytensor.tensor as pt
9-
from pytensor import shared
9+
from pytensor import graph_replace, shared
1010
from pytensor.compile import optdb
1111
from pytensor.compile.function import function
1212
from pytensor.compile.mode import get_default_mode, get_mode
@@ -2010,3 +2010,19 @@ def test_topological_fill_sink_multi_output_client():
20102010
[new_out] = fg.outputs
20112011
expected_out = pt.full_like(z, pt.add(*elem_op_with_2_outputs(pt.exp(x))))
20122012
assert equal_computations([new_out], [expected_out])
2013+
2014+
2015+
def test_topological_fill_sink_broadcastable_change():
2016+
"""Test rewrite doesn't fail after a graph replacement that provides a broadcastable change."""
2017+
a = vector("a", shape=(1,))
2018+
b = vector("b", shape=(1,))
2019+
zeros = pt.vector("zeros", shape=(None,))
2020+
initial_out = pt.full_like(zeros, a) + b
2021+
2022+
# Make broadcast to zeros irrelevant
2023+
out = graph_replace(initial_out, {zeros: pt.zeros((1,))}, strict=False)
2024+
2025+
fg = FunctionGraph([a, b], [out], copy_inputs=False)
2026+
topological_fill_sink.rewrite(fg)
2027+
[new_out] = fg.outputs
2028+
assert equal_computations([new_out], [a + b])

0 commit comments

Comments
 (0)