diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index da324d5df8..e71fb271de 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -1339,7 +1339,7 @@ def save_mem_new_scan(fgraph, node): stop = at.extract_constant(cf_slice[0].stop) else: stop = at.extract_constant(cf_slice[0]) + 1 - if stop == maxsize or stop == length: + if stop == maxsize or stop == at.extract_constant(length): stop = None else: # there is a **gotcha** here ! Namely, scan returns an @@ -1516,13 +1516,17 @@ def save_mem_new_scan(fgraph, node): if ( nw_inputs[offset + idx].owner and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor) + and nw_inputs[offset + idx].owner.op.set_instead_of_inc and isinstance( nw_inputs[offset + idx].owner.op.idx_list[0], slice ) - ): - assert isinstance( - nw_inputs[offset + idx].owner.op, IncSubtensor + # Don't try to create a smart Alloc, if set_subtensor is broadcasting the fill value + # As it happens in set_subtensor(empty(2)[:], 0) + and not ( + nw_inputs[offset + idx].ndim + > nw_inputs[offset + idx].owner.inputs[1].ndim ) + ): _nw_input = nw_inputs[offset + idx].owner.inputs[1] cval = at.as_tensor_variable(val) initl = at.as_tensor_variable(init_l[i]) diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py index da53093b12..21b23b3433 100644 --- a/tests/scan/test_rewriting.py +++ b/tests/scan/test_rewriting.py @@ -1487,6 +1487,22 @@ def test_while_scan_taps_and_map(self): assert stored_ys_steps == 2 assert stored_zs_steps == 1 + def test_vector_zeros_init(self): + ys, _ = pytensor.scan( + fn=lambda ytm2, ytm1: ytm1 + ytm2, + outputs_info=[{"initial": at.zeros(2), "taps": range(-2, 0)}], + n_steps=100, + ) + + fn = pytensor.function([], ys[-50:], mode=self.mode) + assert tuple(fn().shape) == (50,) + + # Check that rewrite worked + [scan_node] = (n for n in fn.maker.fgraph.apply_nodes if isinstance(n.op, Scan)) + _, ys_trace = scan_node.inputs + debug_fn = pytensor.function([], ys_trace.shape[0], accept_inplace=True) + assert debug_fn() == 50 + def test_inner_replace_dot(): """