|
6 | 6 | import pytensor
|
7 | 7 | import pytensor.scalar as ps
|
8 | 8 | import pytensor.tensor as pt
|
9 |
| -from pytensor import shared |
| 9 | +from pytensor import graph_replace, shared |
10 | 10 | from pytensor.compile import optdb
|
11 | 11 | from pytensor.compile.function import function
|
12 | 12 | from pytensor.compile.mode import get_default_mode, get_mode
|
@@ -2010,3 +2010,19 @@ def test_topological_fill_sink_multi_output_client():
|
2010 | 2010 | [new_out] = fg.outputs
|
2011 | 2011 | expected_out = pt.full_like(z, pt.add(*elem_op_with_2_outputs(pt.exp(x))))
|
2012 | 2012 | assert equal_computations([new_out], [expected_out])
|
| 2013 | + |
| 2014 | + |
| 2015 | +def test_topological_fill_sink_broadcastable_change(): |
| 2016 | + """Test rewrite doesn't fail after a graph replacement that provides a broadcastable change.""" |
| 2017 | + a = vector("a", shape=(1,)) |
| 2018 | + b = vector("b", shape=(1,)) |
| 2019 | + zeros = pt.vector("zeros", shape=(None,)) |
| 2020 | + initial_out = pt.full_like(zeros, a) + b |
| 2021 | + |
| 2022 | + # Make broadcast to zeros irrelevant |
| 2023 | + out = graph_replace(initial_out, {zeros: pt.zeros((1,))}, strict=False) |
| 2024 | + |
| 2025 | + fg = FunctionGraph([a, b], [out], copy_inputs=False) |
| 2026 | + topological_fill_sink.rewrite(fg) |
| 2027 | + [new_out] = fg.outputs |
| 2028 | + assert equal_computations([new_out], [a + b]) |
0 commit comments