diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py
index c49fbadce4..c5ac0a28a3 100644
--- a/pytensor/scan/rewriting.py
+++ b/pytensor/scan/rewriting.py
@@ -53,6 +53,7 @@
 from pytensor.tensor.basic import (
     Alloc,
     AllocEmpty,
+    atleast_Nd,
     get_scalar_constant_value,
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -1186,8 +1187,8 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
     return subtensor_merge_replacements


-def _is_default_scan_buffer(x: TensorVariable) -> bool:
-    node = x.owner
+def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
+    node = final_buffer.owner

     if node is None:
         return False
@@ -1200,8 +1201,10 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
     ):
         return False

-    x, y, *_ = node.inputs
-    if not (x.owner is not None and isinstance(x.owner.op, AllocEmpty)):
+    init_buffer, init_value, *_ = node.inputs
+    if not (
+        init_buffer.owner is not None and isinstance(init_buffer.owner.op, AllocEmpty)
+    ):
         return False

     # The value may have been broadcast to fill in the initial taps.
@@ -1218,10 +1221,16 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
     # 1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
     # But due to laziness we use the slightly more conservative check:
     # 2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
-    if broadcasted_by(y, x):
-        return False
-
-    return True
+    if taps > 1:
+        return not broadcasted_by(init_value, init_buffer)
+    else:
+        # In this case we know we have alloc_empty(1 + nsteps, ...)[:1].set(init_value)
+        # The first dimension cannot possibly broadcast in the subtensor assignment,
+        # so we exclude it from `broadcasted_by`. To exclude it we squeeze it out,
+        # after adding any other implicit expand_dims. We select into the first entry of
+        # the buffer, to check for potential broadcasting in other dimensions.
+        init_value_ = atleast_Nd(init_value, n=init_buffer.ndim)
+        return not broadcasted_by(init_value_.squeeze(0), init_buffer[0])


 def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
@@ -1574,15 +1583,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                 # If the memory for this output has been pre-allocated
                 # before going into the scan op (by an alloc node)
                 if idx < op_info.n_mit_sot + op_info.n_sit_sot:
+                    taps = init_l[i]
                     nw_input = nw_inputs[offset + idx]

                     # Recreate default buffers with new size
-                    if _is_default_scan_buffer(nw_input):
-                        extra_size = 1 if required_orphan else val - init_l[i]
+                    if _is_default_scan_buffer(nw_input, taps):
+                        extra_size = 1 if required_orphan else val - taps
                         nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
                     # Otherwise, just trim with a slice
                     else:
-                        stop = init_l[i] if required_orphan else val
+                        stop = taps if required_orphan else val
                         nw_input = nw_input[:stop]

                     nw_inputs[offset + idx] = nw_input
@@ -1626,14 +1636,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
             # val == 0 means that we want to keep all intermediate
            # results for that state, including the initial values.
             if idx < op_info.n_mit_sot + op_info.n_sit_sot:
+                taps = init_l[op_info.n_mit_mot + idx]
                 in_idx = offset + idx
                 nw_input = nw_inputs[in_idx]
-                if _is_default_scan_buffer(nw_input):
+                if _is_default_scan_buffer(nw_input, taps):
                     nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
                 else:
-                    # Number of steps in the initial state
-                    init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
-                    nw_input = nw_input[: (init_l_pt + nw_steps)]
+                    nw_input = nw_input[: (taps + nw_steps)]
                 nw_inputs[in_idx] = nw_input

             elif (
diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 59148fae3b..bf5f9d6bd7 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -95,9 +95,11 @@ def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool:
     """
     bx = x.type.broadcastable
     by = y.type.broadcastable
-    if len(bx) < len(by):
+    bx_len = len(bx)
+    by_len = len(by)
+    if bx_len < by_len:
         return True
-    bx = bx[-len(by) :]
+    bx = bx[bx_len - by_len :]
     return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by, strict=True))


diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py
index 1b687afcdc..1b7fac98a4 100644
--- a/tests/scan/test_rewriting.py
+++ b/tests/scan/test_rewriting.py
@@ -9,13 +9,14 @@
 from pytensor.compile.mode import get_default_mode
 from pytensor.configdefaults import config
 from pytensor.gradient import grad, jacobian
-from pytensor.graph.basic import Constant, equal_computations
+from pytensor.graph.basic import Constant, ancestors, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import clone_replace
 from pytensor.scan.op import Scan
 from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
 from pytensor.scan.utils import until
 from pytensor.tensor import stack
+from pytensor.tensor.basic import AllocEmpty
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.math import Dot, dot, sigmoid, tanh
@@ -1207,7 +1208,7 @@ def test_inplace3(self):


 class TestSaveMem:
-    mode = get_default_mode().including("scan_save_mem")
+    mode = get_default_mode().including("scan_save_mem").excluding("scan_pushout")

     def test_save_mem(self):
         rng = np.random.default_rng(utt.fetch_seed())
@@ -1371,7 +1372,7 @@ def test_save_mem_cannot_reduce_constant_number_of_steps(self):
         )

     def test_save_mem_store_steps(self):
-        def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
+        def step(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
             return (
                 u_t + 1.0,
                 u_t + 2.0,
@@ -1388,7 +1389,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
         x30 = vector("x30")
         x40 = scalar("x40")
         [x1, x2, x3, x4, x5, x6, x7], updates = scan(
-            f_rnn,
+            step,
             u,
             [
                 None,
@@ -1404,7 +1405,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
             go_backwards=False,
         )

-        f2 = function(
+        f = function(
             [u, x10, x20, x30, x40],
             [x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]],
             updates=updates,
@@ -1417,13 +1418,51 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
         v_u = rng.uniform(-5.0, 5.0, size=(20,))

         # compute the output in numpy
-        tx1, tx2, tx3, tx4, tx5 = f2(v_u, [0, 0], 0, [0, 0], 0)
-
-        utt.assert_allclose(tx1, v_u[-7] + 1.0)
-        utt.assert_allclose(tx2, v_u[-3:-1] + 2.0)
-        utt.assert_allclose(tx3, v_u[-6:] + 3.0)
-        utt.assert_allclose(tx4, v_u[-1] + 4.0)
-        utt.assert_allclose(tx5, v_u[-1] + 5.0)
+        tx1, tx2, tx3, tx4, tx5 = f(v_u, [0, 0], 0, [0, 0], 0)
+        rtol = 1e-7 if config.floatX == "float64" else 1e-6
+        np.testing.assert_allclose(tx1, v_u[-7] + 1.0, rtol=rtol)
+        np.testing.assert_allclose(tx2, v_u[-3:-1] + 2.0, rtol=rtol)
+        np.testing.assert_allclose(tx3, v_u[-6:] + 3.0, rtol=rtol)
+        np.testing.assert_allclose(tx4, v_u[-1] + 4.0, rtol=rtol)
+        np.testing.assert_allclose(tx5, v_u[-1] + 5.0, rtol=rtol)
+
+        # Confirm reduction in buffer sizes
+        [scan_node] = [
+            node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
+        ]
+        # x6 and x7 are dropped because they are not used
+        [n_steps, seq, x4_buffer, x5_buffer, x1_len, x2_len, x3_len] = scan_node.inputs
+        [x4_underlying_alloc] = [
+            var
+            for var in ancestors([x4_buffer])
+            if var.owner and isinstance(var.owner.op, AllocEmpty)
+        ]
+        [x5_underlying_alloc] = [
+            var
+            for var in ancestors([x5_buffer])
+            if var.owner and isinstance(var.owner.op, AllocEmpty)
+        ]
+        buffer_lengths = pytensor.function(
+            [u, x10, x20, x30, x40],
+            [
+                x1_len,
+                x2_len,
+                x3_len,
+                x4_underlying_alloc.shape[0],
+                x5_underlying_alloc.shape[0],
+            ],
+            accept_inplace=True,
+            on_unused_input="ignore",
+            allow_input_downcast=True,
+        )(v_u, [0, 0], 0, [0, 0], 0)
+        # ScanSaveMem keeps +1 entries to handle taps with preallocated outputs
+        assert [int(i) for i in buffer_lengths] == [
+            7,  # entry -7 of a map variable is kept, we need at least that many
+            3,  # entries [-3, -2] of a map variable are kept, we need at least 3
+            6,  # last six entries of a map variable are kept
+            2 + 1,  # last entry of a double tap variable is kept
+            1 + 1,  # last entry of a single tap variable is kept
+        ]

     def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
         var = pt.ones(())
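
Illustrative note (not part of the patch): the updated _is_default_scan_buffer(final_buffer, taps) only treats a buffer as "default" when the initial value is not broadcast by the AllocEmpty storage, because a recreated, resized buffer would otherwise change the output shape; when it is broadcast, the rewrite falls back to trimming the existing buffer with a slice. A minimal sketch of the broadcasted_by check under these assumptions, using hypothetical names (nsteps, init, col) chosen only for illustration:

import pytensor.tensor as pt
from pytensor.tensor.rewriting.basic import broadcasted_by
from pytensor.tensor.subtensor import set_subtensor

nsteps = 5
init = pt.vector("init")  # single-tap initial state, shape (n,)
empty = pt.empty((1 + nsteps, init.shape[0]))  # AllocEmpty storage: taps + nsteps rows
buffer = set_subtensor(empty[:1], init[None])  # default buffer: first tap filled in

# With taps == 1 the leading (time) dimension can never be broadcast by the
# subtensor assignment, so only the remaining dimensions are compared:
assert not broadcasted_by(init, empty[0])

# A broadcastable initial value filling a wider buffer *is* broadcast, so such a
# buffer is not considered "default" and would be trimmed with a slice instead:
col = pt.col("col")  # shape (n, 1), last dim broadcastable
wide = pt.empty((1 + nsteps, col.shape[0], 3))
assert broadcasted_by(col, wide[0])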