diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index 82931bced6..ffa27e5d5a 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -454,6 +454,19 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
    RewriteDatabaseQuery(include=["fast_run", "py_only"]),
 )
+NUMBA = Mode(
+    NumbaLinker(),
+    RewriteDatabaseQuery(
+        include=["fast_run", "numba"],
+        exclude=[
+            "cxx_only",
+            "BlasOpt",
+            "local_careduce_fusion",
+            "scan_save_mem_prealloc",
+        ],
+    ),
+)
+
 JAX = Mode(
    JAXLinker(),
    RewriteDatabaseQuery(
@@ -463,6 +476,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
            "BlasOpt",
            "fusion",
            "inplace",
+            "scan_save_mem_prealloc",
        ],
    ),
 )
@@ -476,16 +490,10 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
            "fusion",
            "inplace",
            "local_uint_constant_indices",
+            "scan_save_mem_prealloc",
        ],
    ),
 )
-NUMBA = Mode(
-    NumbaLinker(),
-    RewriteDatabaseQuery(
-        include=["fast_run", "numba"],
-        exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
-    ),
-)
 predefined_modes = {
diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py
index 8ad3fa61f4..ca3c44bf6d 100644
--- a/pytensor/configdefaults.py
+++ b/pytensor/configdefaults.py
@@ -1085,7 +1085,9 @@ def add_scan_configvars():
        "scan__allow_output_prealloc",
        "Allow/disallow memory preallocation for outputs inside of scan "
        "(default: True)",
-        BoolParam(True),
+        # Non-mutable because the ScanSaveMem rewrite checks it,
+        # and we can't have the rewrite and the implementation disagree
+        BoolParam(True, mutable=False),
        in_c_key=False,
    )
diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py
index b638570bd1..7ff939b43f 100644
--- a/pytensor/link/jax/dispatch/scan.py
+++ b/pytensor/link/jax/dispatch/scan.py
@@ -29,7 +29,7 @@ def scan(*outer_inputs):
        # Extract JAX scan inputs
        outer_inputs = list(outer_inputs)
        n_steps = outer_inputs[0]  # JAX `length`
-        seqs = op.outer_seqs(outer_inputs)  # JAX `xs`
+        seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)]  # JAX `xs`
        mit_sot_init = []
        for tap, seq in zip(
diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py
index 62e4a0608f..c75a4cf890 100644
--- a/pytensor/link/numba/dispatch/scan.py
+++ b/pytensor/link/numba/dispatch/scan.py
@@ -55,7 +55,7 @@ def range_arr(x):
 @numba_funcify.register(Scan)
-def numba_funcify_Scan(op, node, **kwargs):
+def numba_funcify_Scan(op: Scan, node, **kwargs):
    # Apply inner rewrites
    # TODO: Not sure this is the right place to do this, should we have a rewrite that
    # explicitly triggers the optimization of the inner graphs of Scan?
@@ -67,9 +67,32 @@ def numba_funcify_Scan(op, node, **kwargs):
        .optimizer
    )
    fgraph = op.fgraph
+    # When the buffer can only hold one SITSOT, or as many MITSOT entries as there are taps,
+    # we must always discard the oldest tap, so it's safe to destroy it in the inner function.
+ # TODO: Allow inplace for MITMOT + destroyable_sitsot = [ + inner_sitsot + for outer_sitsot, inner_sitsot in zip( + op.outer_sitsot(node.inputs), op.inner_sitsot(fgraph.inputs), strict=True + ) + if outer_sitsot.type.shape[0] == 1 + ] + destroyable_mitsot = [ + oldest_inner_mitmot + for outer_mitsot, oldest_inner_mitmot, taps in zip( + op.outer_mitsot(node.inputs), + op.oldest_inner_mitsot(fgraph.inputs), + op.info.mit_sot_in_slices, + strict=True, + ) + if outer_mitsot.type.shape[0] == abs(min(taps)) + ] + destroyable = {*destroyable_sitsot, *destroyable_mitsot} add_supervisor_to_fgraph( fgraph=fgraph, - input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs], + input_specs=[ + In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs + ], accept_inplace=True, ) rewriter(fgraph) @@ -222,14 +245,16 @@ def add_output_storage_post_proc_stmt( # the storage array. # This is needed when the output storage array does not have a length # equal to the number of taps plus `n_steps`. + # If the storage size only allows one entry, there's nothing to rotate output_storage_post_proc_stmts.append( dedent( f""" - if (i + {tap_size}) > {storage_size}: + if 1 < {storage_size} < (i + {tap_size}): {outer_in_name}_shift = (i + {tap_size}) % ({storage_size}) - {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift] - {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:] - {outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left)) + if {outer_in_name}_shift > 0: + {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift] + {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:] + {outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left)) """ ).strip() ) @@ -417,4 +442,4 @@ def scan({", ".join(outer_in_names)}): scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env}) - return numba_basic.numba_njit(scan_op_fn) + return numba_basic.numba_njit(scan_op_fn, boundscheck=False) diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 7e7e3b2cee..4f2739fb69 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -321,6 +321,16 @@ def inner_mitsot(self, list_inputs): self.info.n_seqs + n_mitmot_taps : self.info.n_seqs + ntaps_upto_sit_sot ] + def oldest_inner_mitsot(self, list_inputs): + inner_mitsot_inputs = self.inner_mitsot(list_inputs) + oldest_inner_mitsot_inputs = [] + offset = 0 + for taps in self.info.mit_sot_in_slices: + oldest_tap = np.argmin(taps) + oldest_inner_mitsot_inputs += [inner_mitsot_inputs[offset + oldest_tap]] + offset += len(taps) + return oldest_inner_mitsot_inputs + def outer_mitsot(self, list_inputs): offset = 1 + self.info.n_seqs + self.info.n_mit_mot return list_inputs[offset : offset + self.info.n_mit_sot] diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index 2ba282d8d6..3b74471cd4 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -3,7 +3,6 @@ import copy import dataclasses from itertools import chain -from sys import maxsize from typing import cast import numpy as np @@ -71,7 +70,7 @@ get_slice_elements, set_subtensor, ) -from pytensor.tensor.variable import TensorConstant +from pytensor.tensor.variable import TensorConstant, TensorVariable list_opt_slice = [ @@ -1183,8 +1182,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node): return subtensor_merge_replacements -@node_rewriter([Scan]) -def scan_save_mem(fgraph, node): +def scan_save_mem_rewrite(fgraph, node, 
backend_supports_output_pre_allocation: bool):
    r"""Graph optimizer that reduces scan memory consumption.
    This optimizations attempts to determine if a `Scan` node, during its execution,
@@ -1215,10 +1213,16 @@ def scan_save_mem(fgraph, node):
    The scan perform implementation takes the output sizes into consideration,
    saving the newest results over the oldest ones whenever the buffer is filled.
-    """
-    if not isinstance(node.op, Scan):
-        return False
+    Parameters
+    ----------
+    backend_supports_output_pre_allocation: bool
+        When the backend supports output pre-allocation, Scan must keep buffers
+        with a length of required_states + 1, because the inner function will
+        attempt to write the inner function outputs directly into the provided
+        position in the outer circular buffer. This would invalidate results
+        if the input is still needed for some other output computation.
+    """
    if hasattr(fgraph, "shape_feature"):
        shape_of = fgraph.shape_feature.shape_of
    else:
@@ -1271,6 +1275,7 @@ def scan_save_mem(fgraph, node):
    # Note: For simplicity while Scans also have global_nsteps set to None.
    # All step optimizations require knowing the shape of the output, which
    # cannot be determined from the inputs alone.
+    global_nsteps: None | dict
    assert len(node.outputs) >= c_outs
    if len(node.outputs) == c_outs and not op.info.as_while:
        global_nsteps = {"real": -1, "sym": []}
@@ -1278,7 +1283,7 @@ def scan_save_mem(fgraph, node):
        global_nsteps = None
    # Keeps track of the original slices that each client represent
-    slices = [None for o in node.outputs]
+    slices: list[None | list] = [None for o in node.outputs]
    # A list for each output indicating how many intermediate values
    # should be stored. If negative it means none of the intermediate
@@ -1295,7 +1300,7 @@ def scan_save_mem(fgraph, node):
    # or not
    flag_store = False
-    # 2.2 Loop over the clients
+    # 2.2 Loop over the clients to figure out how many steps we actually need to do in the Scan
    for i, out in enumerate(node.outputs[:c_outs]):
        # look at all its clients
        slices[i] = []
@@ -1338,7 +1343,7 @@ def scan_save_mem(fgraph, node):
                except KeyError:
                    length = out.shape[0]
                cf_slice = get_canonical_form_slice(this_slice[0], length)
-                slices[i] += [(cf_slice, this_slice)]
+                slices[i] += [(cf_slice, this_slice)]  # type: ignore
                if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
                    global_nsteps = None
@@ -1351,10 +1356,9 @@ def scan_save_mem(fgraph, node):
                        get_scalar_constant_value(cf_slice[0], raise_not_constant=False)
                        + 1
                    )
-                    if stop == maxsize or stop == get_scalar_constant_value(
-                        length, raise_not_constant=False
-                    ):
+                    if stop == get_scalar_constant_value(length, raise_not_constant=False):
                        stop = None
+                        global_nsteps = None
                    else:
                        # there is a **gotcha** here !
Namely, scan returns an # array that contains the initial state of the output @@ -1366,21 +1370,13 @@ def scan_save_mem(fgraph, node): # initial state) stop = stop - init_l[i] - # 2.3.3 we might get away with less number of steps + # 2.3.3 we might get away with fewer steps if stop is not None and global_nsteps is not None: # yes if it is a tensor if isinstance(stop, Variable): global_nsteps["sym"] += [stop] - # not if it is maxsize - elif isinstance(stop, int) and stop == maxsize: - global_nsteps = None - # yes if it is a int k, 0 < k < maxsize - elif isinstance(stop, int) and global_nsteps["real"] < stop: - global_nsteps["real"] = stop - # yes if it is a int k, 0 < k < maxsize - elif isinstance(stop, int) and stop > 0: - pass - # not otherwise + elif isinstance(stop, int | np.integer): + global_nsteps["real"] = max(global_nsteps["real"], stop) else: global_nsteps = None @@ -1430,9 +1426,18 @@ def scan_save_mem(fgraph, node): store_steps[i] = 0 break - if isinstance(this_slice[0], slice) and this_slice[0].start is None: - store_steps[i] = 0 - break + if isinstance(this_slice[0], slice): + start = this_slice[0].start + if isinstance(start, Constant): + start = start.data + # Don't do anything if the subtensor is starting from the beginning of the buffer + # Or just skipping the initial values (default output returned to the user). + # Trimming the initial values would require a roll to align the buffer once scan is done + # As it always starts writing at position [0+max(taps)], and ends up at position [:max(taps)] + # It's cheaper to just keep the initial values in the buffer and slice them away (default output) + if start in (0, None, init_l[i]): + store_steps[i] = 0 + break # Special case for recurrent outputs where only the last result # is requested. This is needed for this rewrite to apply to @@ -1478,7 +1483,10 @@ def scan_save_mem(fgraph, node): # for mitsots and sitsots (because mitmots are not # currently supported by the mechanism) and only if # the pre-allocation mechanism is activated. 
- prealloc_outs = config.scan__allow_output_prealloc + prealloc_outs = ( + backend_supports_output_pre_allocation + and config.scan__allow_output_prealloc + ) first_mitsot_idx = op_info.n_mit_mot last_sitsot_idx = ( @@ -1487,6 +1495,8 @@ def scan_save_mem(fgraph, node): preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx if prealloc_outs and preallocable_output: + # TODO: If there's only one output or other outputs do not depend + # on the same input, we could reduce the buffer size to the minimum pval = select_max(nw_steps - start + init_l[i], init_l[i] + 1) else: pval = select_max(nw_steps - start + init_l[i], init_l[i]) @@ -1653,7 +1663,7 @@ def scan_save_mem(fgraph, node): name=op.name, allow_gc=op.allow_gc, ) - new_outs = new_op(*node_ins, return_list=True) + new_outs = cast(list[TensorVariable], new_op(*node_ins, return_list=True)) old_new = [] # 3.7 Get replace pairs for those outputs that do not change @@ -1683,7 +1693,7 @@ def scan_save_mem(fgraph, node): sl_ins = get_slice_elements( nw_slice, lambda entry: isinstance(entry, Variable) ) - new_o = subtens(new_outs[nw_pos], *sl_ins) + new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins)) if new_o.ndim > 0: new_o = new_o[:: cnf_slice[1]] replaced_outs.append(idx) @@ -1703,10 +1713,7 @@ def scan_save_mem(fgraph, node): - init_l[pos] + store_steps[pos] ) - if ( - cnf_slice[0].stop is not None - and cnf_slice[0].stop != maxsize - ): + if cnf_slice[0].stop is not None: stop = ( cnf_slice[0].stop - nw_steps @@ -1741,7 +1748,7 @@ def scan_save_mem(fgraph, node): sl_ins = get_slice_elements( nw_slice, lambda entry: isinstance(entry, Variable) ) - new_o = subtens(new_outs[nw_pos], *sl_ins) + new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins)) if new_o.ndim > 0: new_o = new_o[:: cnf_slice[1]] old_new += [(old, new_o)] @@ -1772,6 +1779,20 @@ def scan_save_mem(fgraph, node): return False +@node_rewriter([Scan]) +def scan_save_mem_prealloc(fgraph, node): + return scan_save_mem_rewrite( + fgraph, node, backend_supports_output_pre_allocation=True + ) + + +@node_rewriter([Scan]) +def scan_save_mem_no_prealloc(fgraph, node): + return scan_save_mem_rewrite( + fgraph, node, backend_supports_output_pre_allocation=False + ) + + class ScanMerge(GraphRewriter): r"""Graph optimizer that merges different scan ops. @@ -2499,10 +2520,20 @@ def scan_push_out_dot1(fgraph, node): optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6) # ScanSaveMem should execute only once per node. 
optdb.register( - "scan_save_mem", - in2out(scan_save_mem, ignore_newtrees=True), + "scan_save_mem_prealloc", + in2out(scan_save_mem_prealloc, ignore_newtrees=True), "fast_run", "scan", + "scan_save_mem", + position=1.61, +) +optdb.register( + "scan_save_mem_no_prealloc", + in2out(scan_save_mem_no_prealloc, ignore_newtrees=True), + "numba", + "jax", + "pytorch", + use_db_name_as_tag=False, position=1.61, ) optdb.register( diff --git a/pytensor/scan/utils.py b/pytensor/scan/utils.py index 611012b97e..6a0cdde461 100644 --- a/pytensor/scan/utils.py +++ b/pytensor/scan/utils.py @@ -231,8 +231,8 @@ def expand_empty(tensor_var, size): if size == 0: return tensor_var - shapes = [tensor_var.shape[x] for x in range(tensor_var.ndim)] - new_shape = [size + shapes[0]] + shapes[1:] + shapes = tuple(tensor_var.shape) + new_shape = (size + shapes[0], *shapes[1:]) empty = AllocEmpty(tensor_var.dtype)(*new_shape) ret = set_subtensor(empty[: shapes[0]], tensor_var) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index e30887cfe3..d33fd5a521 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -2943,6 +2943,8 @@ def stack(tensors: Sequence["TensorLike"], axis: int = 0): ): # In case there is direct scalar tensors = list(map(as_tensor_variable, tensors)) + if len(tensors) == 1: + return atleast_1d(tensors[0]) dtype = ps.upcast(*[i.dtype for i in tensors]) return MakeVector(dtype)(*tensors) return join(axis, *[shape_padaxis(t, axis) for t in tensors]) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 3de4f41068..8e3e5cb902 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -1,7 +1,7 @@ import logging import sys import warnings -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Sequence from itertools import chain, groupby from textwrap import dedent from typing import cast, overload @@ -33,7 +33,9 @@ alloc, get_scalar_constant_value, nonzero, - scalar_from_tensor, +) +from pytensor.tensor.basic import ( + constant as tensor_constant, ) from pytensor.tensor.blockwise import vectorize_node_fallback from pytensor.tensor.elemwise import DimShuffle @@ -256,20 +258,20 @@ def get_idx_list(inputs, idx_list): def get_canonical_form_slice( theslice: slice, length: int | np.integer | ScalarVariable | TensorVariable, -) -> tuple[slice, int | ScalarConstant]: ... +) -> tuple[slice, int | TensorVariable]: ... @overload def get_canonical_form_slice( theslice: int | np.integer | ScalarVariable | TensorVariable, length: int | np.integer | ScalarVariable | TensorVariable, -) -> tuple[ScalarVariable, int]: ... +) -> tuple[TensorVariable, int]: ... def get_canonical_form_slice( theslice: slice | int | np.integer | ScalarVariable | TensorVariable, length: int | np.integer | ScalarVariable | TensorVariable, -) -> tuple[slice | ScalarVariable, int | ScalarConstant]: +) -> tuple[slice | TensorVariable, int | TensorVariable]: """Convert indices or slices to canonical form. 
Scalar integer indices or python Slices with Scalar/None attributes @@ -296,30 +298,56 @@ def get_canonical_form_slice( """ from pytensor.tensor import ge, lt, sign, switch - # Other non-slice types are the scalar indexing case - if not isinstance(theslice, slice): - if isinstance(theslice, int | np.integer | ScalarVariable) or ( - isinstance(theslice, TensorVariable) and theslice.ndim == 0 - ): - cano = switch(lt(theslice, 0), (theslice + length), theslice) - return scalar_from_tensor(cano), 1 - raise ValueError(f"Slice {theslice} is not a supported slice type.") + def undo_scalarization(x): + """Undo scalarization of a variable. - # At this point we have a slice object. Possibly with symbolic inputs. + PyTensor Basic index operations use ScalarVariables for the indices/slice arguments. + But reasoning symbolically about the result of multiple indexing operations, we usually + want to work on TensorVariables, since rewrites work on those and not ScalarVariables. + + This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants. + """ + if isinstance(x, ScalarVariable): + if isinstance(x, ScalarConstant): + return tensor_constant(x.data, dtype=x.dtype) + elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor): + return x.owner.inputs[0] + else: + return as_tensor_variable(x) + return x def analyze(x): try: x_constant = as_index_literal(x) is_constant = True except NotScalarConstantError: - x_constant = x + x_constant = undo_scalarization(x) is_constant = False return x_constant, is_constant + length, is_length_constant = analyze(length) + + # Other non-slice types are the scalar indexing case + if not isinstance(theslice, slice): + if not ( + isinstance(theslice, int | np.integer | ScalarVariable) + or (isinstance(theslice, TensorVariable) and theslice.ndim == 0) + ): + raise ValueError(f"Slice {theslice} is not a supported slice type.") + + idx, is_index_constant = analyze(theslice) + if is_index_constant: + if idx >= 0: + return idx, 1 + else: + return idx + length, 1 + else: + return switch(lt(idx, 0), idx + length, idx), 1 + + # At this point we have a slice object. Possibly with symbolic inputs. start, is_start_constant = analyze(theslice.start) stop, is_stop_constant = analyze(theslice.stop) step, is_step_constant = analyze(theslice.step) - length, is_length_constant = analyze(length) if ( is_start_constant @@ -645,7 +673,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False): def get_slice_elements( - idxs: list, + idxs: Sequence, cond: Callable = lambda x: isinstance(x, Variable), ) -> list: """Extract slice elements conditional on a given predicate function. 
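Reviewer note (illustration only, not part of the patch): with the updated `get_canonical_form_slice`, a negative constant index is canonicalized into a `TensorVariable` expression over the length instead of a `ScalarVariable`, which is what the revised tests further down assert. A minimal sketch, assuming the API shown above; the variable names are made up for the example.

import pytensor.tensor as pt
from pytensor.tensor.subtensor import get_canonical_form_slice

length = pt.lscalar("length")
idx, direction = get_canonical_form_slice(pt.as_tensor(-1), length)
# `idx` is symbolically equivalent to `length - 1`, and `direction` is 1,
# so downstream rewrites can keep reasoning on TensorVariables.
print(idx.eval({length: 10}), direction)  # expected: 9 1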
diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index 037155880e..8c0d9d4f52 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -339,39 +339,6 @@ def power_step(prior_result, x): compare_numba_and_py([A], result, test_input_vals) -@pytest.mark.parametrize("n_steps_val", [1, 5]) -def test_scan_save_mem_basic(n_steps_val): - """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite.""" - - def f_pow2(x_tm2, x_tm1): - return 2 * x_tm1 + x_tm2 - - init_x = pt.dvector("init_x") - n_steps = pt.iscalar("n_steps") - output, _ = scan( - f_pow2, - sequences=[], - outputs_info=[{"initial": init_x, "taps": [-2, -1]}], - non_sequences=[], - n_steps=n_steps, - ) - - state_val = np.array([1.0, 2.0]) - - numba_mode = get_mode("NUMBA").including("scan_save_mem") - py_mode = Mode("py").including("scan_save_mem") - - test_input_vals = (state_val, n_steps_val) - - compare_numba_and_py( - [init_x, n_steps], - [output], - test_input_vals, - numba_mode=numba_mode, - py_mode=py_mode, - ) - - def test_grad_sitsot(): def get_sum_of_grad(inp): scan_outputs, updates = scan( @@ -482,3 +449,193 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1): np.testing.assert_array_almost_equal(numba_r, ref_r) benchmark(numba_fn, *test.values()) + + +@pytest.mark.parametrize("n_steps_constant", (True, False)) +def test_inplace_taps(n_steps_constant): + """Test that numba will inplace in the inner_function of the oldest sit-sot, mit-sot taps.""" + n_steps = 10 if n_steps_constant else scalar("n_steps", dtype=int) + a = scalar("a") + x0 = scalar("x0") + y0 = vector("y0", shape=(2,)) + z0 = vector("z0", shape=(3,)) + + def step(ztm3, ztm1, xtm1, ytm1, ytm2, a): + z = ztm1 + 1 + ztm3 + a + x = xtm1 + 1 + y = ytm1 + 1 + ytm2 + a + return z, x, z + x + y, y + + [zs, xs, ws, ys], _ = scan( + fn=step, + outputs_info=[ + dict(initial=z0, taps=[-3, -1]), + dict(initial=x0, taps=[-1]), + None, + dict(initial=y0, taps=[-1, -2]), + ], + non_sequences=[a], + n_steps=n_steps, + ) + numba_fn, _ = compare_numba_and_py( + [n_steps] * (not n_steps_constant) + [a, x0, y0, z0], + [zs[-1], xs[-1], ws[-1], ys[-1]], + [10] * (not n_steps_constant) + [np.pi, np.e, [1, np.euler_gamma], [0, 1, 2]], + numba_mode="NUMBA", + eval_obj_mode=False, + ) + [scan_op] = [ + node.op + for node in numba_fn.maker.fgraph.toposort() + if isinstance(node.op, Scan) + ] + + # Scan reorders inputs internally, so we need to check its ordering + inner_inps = scan_op.fgraph.inputs + mit_sot_inps = scan_op.inner_mitsot(inner_inps) + oldest_mit_sot_inps = [ + # Implicitly assume that the first mit-sot input is the one with 3 taps + # This is not a required behavior and the test can change if we need to change Scan. 
+ mit_sot_inps[:2][scan_op.info.mit_sot_in_slices[0].index(-3)], + mit_sot_inps[2:][scan_op.info.mit_sot_in_slices[1].index(-2)], + ] + [sit_sot_inp] = scan_op.inner_sitsot(inner_inps) + + inner_outs = scan_op.fgraph.outputs + mit_sot_outs = scan_op.inner_mitsot_outs(inner_outs) + [sit_sot_out] = scan_op.inner_sitsot_outs(inner_outs) + [nit_sot_out] = scan_op.inner_nitsot_outs(inner_outs) + + if n_steps_constant: + assert mit_sot_outs[0].owner.op.destroy_map == { + 0: [mit_sot_outs[0].owner.inputs.index(oldest_mit_sot_inps[0])] + } + assert mit_sot_outs[1].owner.op.destroy_map == { + 0: [mit_sot_outs[1].owner.inputs.index(oldest_mit_sot_inps[1])] + } + assert sit_sot_out.owner.op.destroy_map == { + 0: [sit_sot_out.owner.inputs.index(sit_sot_inp)] + } + else: + # This is not a feature, but a current limitation + # https://github.com/pymc-devs/pytensor/issues/1283 + assert mit_sot_outs[0].owner.op.destroy_map == {} + assert mit_sot_outs[1].owner.op.destroy_map == {} + assert sit_sot_out.owner.op.destroy_map == {} + assert nit_sot_out.owner.op.destroy_map == {} + + +@pytest.mark.parametrize( + "buffer_size", ("unit", "aligned", "misaligned", "whole", "whole+init") +) +@pytest.mark.parametrize("n_steps, op_size", [(10, 2), (512, 2), (512, 256)]) +class TestScanSITSOTBuffer: + def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None): + x0 = pt.vector(shape=(op_size,), dtype="float64") + xs, _ = pytensor.scan( + fn=lambda xtm1: (xtm1 + 1), + outputs_info=[x0], + n_steps=n_steps - 1, # 1- makes it easier to align/misalign + ) + if buffer_size == "unit": + xs_kept = xs[-1] # Only last state is used + expected_buffer_size = 1 + elif buffer_size == "aligned": + xs_kept = xs[-2:] # The buffer will be aligned at the end of the 9 steps + expected_buffer_size = 2 + elif buffer_size == "misaligned": + xs_kept = xs[-3:] # The buffer will be misaligned at the end of the 9 steps + expected_buffer_size = 3 + elif buffer_size == "whole": + xs_kept = xs # What users think is the whole buffer + expected_buffer_size = n_steps + elif buffer_size == "whole+init": + xs_kept = xs.owner.inputs[0] # Whole buffer actually used by Scan + expected_buffer_size = n_steps + + x_test = np.zeros(x0.type.shape) + numba_fn, _ = compare_numba_and_py( + [x0], + [xs_kept], + test_inputs=[x_test], + numba_mode="NUMBA", # Default doesn't include optimizations + eval_obj_mode=False, + ) + [scan_node] = [ + node + for node in numba_fn.maker.fgraph.toposort() + if isinstance(node.op, Scan) + ] + buffer = scan_node.inputs[1] + assert buffer.type.shape[0] == expected_buffer_size + + if benchmark is not None: + numba_fn.trust_input = True + benchmark(numba_fn, x_test) + + def test_sit_sot_buffer(self, n_steps, op_size, buffer_size): + self.buffer_tester(n_steps, op_size, buffer_size, benchmark=None) + + def test_sit_sot_buffer_benchmark(self, n_steps, op_size, buffer_size, benchmark): + self.buffer_tester(n_steps, op_size, buffer_size, benchmark=benchmark) + + +@pytest.mark.parametrize("constant_n_steps", [False, True]) +@pytest.mark.parametrize("n_steps_val", [1, 1000]) +class TestScanMITSOTBuffer: + def buffer_tester(self, constant_n_steps, n_steps_val, benchmark=None): + """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite.""" + + def f_pow2(x_tm2, x_tm1): + return 2 * x_tm1 + x_tm2 + + init_x = pt.vector("init_x", shape=(2,)) + n_steps = pt.iscalar("n_steps") + output, _ = scan( + f_pow2, + sequences=[], + outputs_info=[{"initial": init_x, "taps": [-2, -1]}], + non_sequences=[], + 
n_steps=n_steps_val if constant_n_steps else n_steps, + ) + + init_x_val = np.array([1.0, 2.0], dtype=init_x.type.dtype) + test_vals = ( + [init_x_val] + if constant_n_steps + else [init_x_val, np.asarray(n_steps_val, dtype=n_steps.type.dtype)] + ) + numba_fn, _ = compare_numba_and_py( + [init_x] if constant_n_steps else [init_x, n_steps], + [output[-1]], + test_vals, + numba_mode="NUMBA", + eval_obj_mode=False, + ) + + if n_steps_val == 1 and constant_n_steps: + # There's no Scan in the graph when nsteps=constant(1) + return + + # Check the buffer size as been optimized + [scan_node] = [ + node + for node in numba_fn.maker.fgraph.toposort() + if isinstance(node.op, Scan) + ] + [mitsot_buffer] = scan_node.op.outer_mitsot(scan_node.inputs) + mitsot_buffer_shape = mitsot_buffer.shape.eval( + {init_x: init_x_val, n_steps: n_steps_val}, + accept_inplace=True, + on_unused_input="ignore", + ) + assert tuple(mitsot_buffer_shape) == (2,) + if benchmark is not None: + numba_fn.trust_input = True + benchmark(numba_fn, *test_vals) + + def test_mit_sot_buffer(self, constant_n_steps, n_steps_val): + self.buffer_tester(constant_n_steps, n_steps_val, benchmark=None) + + def test_mit_sot_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark): + self.buffer_tester(constant_n_steps, n_steps_val, benchmark=benchmark) diff --git a/tests/scan/test_printing.py b/tests/scan/test_printing.py index 9bf32af48f..70c781a0c9 100644 --- a/tests/scan/test_printing.py +++ b/tests/scan/test_printing.py @@ -44,25 +44,24 @@ def test_debugprint_sitsot(): │ │ │ │ │ │ └─ 1.0 [id O] │ │ │ │ │ └─ 0 [id P] │ │ │ │ └─ Subtensor{i} [id Q] - │ │ │ │ ├─ Shape [id R] - │ │ │ │ │ └─ Unbroadcast{0} [id J] - │ │ │ │ │ └─ ··· - │ │ │ │ └─ 1 [id S] + │ │ │ │ ├─ Shape [id I] + │ │ │ │ │ └─ ··· + │ │ │ │ └─ 1 [id R] │ │ │ ├─ Unbroadcast{0} [id J] │ │ │ │ └─ ··· - │ │ │ └─ ScalarFromTensor [id T] + │ │ │ └─ ScalarFromTensor [id S] │ │ │ └─ Subtensor{i} [id H] │ │ │ └─ ··· │ │ └─ A [id M] (outer_in_non_seqs-0) - │ └─ 1 [id U] - └─ -1 [id V] + │ └─ 1 [id T] + └─ -1 [id U] Inner graphs: Scan{scan_fn, while_loop=False, inplace=none} [id C] - ← Mul [id W] (inner_out_sit_sot-0) - ├─ *0- [id X] -> [id E] (inner_in_sit_sot-0) - └─ *1- [id Y] -> [id M] (inner_in_non_seqs-0) + ← Mul [id V] (inner_out_sit_sot-0) + ├─ *0- [id W] -> [id E] (inner_in_sit_sot-0) + └─ *1- [id X] -> [id M] (inner_in_non_seqs-0) """ for truth, out in zip(expected_output.split("\n"), lines, strict=True): @@ -103,25 +102,24 @@ def test_debugprint_sitsot_no_extra_info(): │ │ │ │ │ │ └─ 1.0 [id O] │ │ │ │ │ └─ 0 [id P] │ │ │ │ └─ Subtensor{i} [id Q] - │ │ │ │ ├─ Shape [id R] - │ │ │ │ │ └─ Unbroadcast{0} [id J] - │ │ │ │ │ └─ ··· - │ │ │ │ └─ 1 [id S] + │ │ │ │ ├─ Shape [id I] + │ │ │ │ │ └─ ··· + │ │ │ │ └─ 1 [id R] │ │ │ ├─ Unbroadcast{0} [id J] │ │ │ │ └─ ··· - │ │ │ └─ ScalarFromTensor [id T] + │ │ │ └─ ScalarFromTensor [id S] │ │ │ └─ Subtensor{i} [id H] │ │ │ └─ ··· │ │ └─ A [id M] - │ └─ 1 [id U] - └─ -1 [id V] + │ └─ 1 [id T] + └─ -1 [id U] Inner graphs: Scan{scan_fn, while_loop=False, inplace=none} [id C] - ← Mul [id W] - ├─ *0- [id X] -> [id E] - └─ *1- [id Y] -> [id M] + ← Mul [id V] + ├─ *0- [id W] -> [id E] + └─ *1- [id X] -> [id M] """ for truth, out in zip(expected_output.split("\n"), lines, strict=True): @@ -288,25 +286,24 @@ def compute_A_k(A, k): │ │ │ │ │ │ │ └─ 1.0 [id BQ] │ │ │ │ │ │ └─ 0 [id BR] │ │ │ │ │ └─ Subtensor{i} [id BS] - │ │ │ │ │ ├─ Shape [id BT] - │ │ │ │ │ │ └─ Unbroadcast{0} [id BL] - │ │ │ │ │ │ └─ ··· - │ │ │ │ │ └─ 1 [id BU] + │ │ │ 
│ │ ├─ Shape [id BK] + │ │ │ │ │ │ └─ ··· + │ │ │ │ │ └─ 1 [id BT] │ │ │ │ ├─ Unbroadcast{0} [id BL] │ │ │ │ │ └─ ··· - │ │ │ │ └─ ScalarFromTensor [id BV] + │ │ │ │ └─ ScalarFromTensor [id BU] │ │ │ │ └─ Subtensor{i} [id BJ] │ │ │ │ └─ ··· │ │ │ └─ *2- [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0) - │ │ └─ 1 [id BW] - │ └─ -1 [id BX] - └─ ExpandDims{axis=0} [id BY] - └─ *1- [id BZ] -> [id U] (inner_in_seqs-1) + │ │ └─ 1 [id BV] + │ └─ -1 [id BW] + └─ ExpandDims{axis=0} [id BX] + └─ *1- [id BY] -> [id U] (inner_in_seqs-1) Scan{scan_fn, while_loop=False, inplace=none} [id BE] - ← Mul [id CA] (inner_out_sit_sot-0) - ├─ *0- [id CB] -> [id BG] (inner_in_sit_sot-0) - └─ *1- [id CC] -> [id BO] (inner_in_non_seqs-0) + ← Mul [id BZ] (inner_out_sit_sot-0) + ├─ *0- [id CA] -> [id BG] (inner_in_sit_sot-0) + └─ *1- [id CB] -> [id BO] (inner_in_non_seqs-0) """ for truth, out in zip(expected_output.split("\n"), lines, strict=True): @@ -386,27 +383,26 @@ def compute_A_k(A, k): │ │ │ │ │ │ │ └─ 1.0 [id BR] │ │ │ │ │ │ └─ 0 [id BS] │ │ │ │ │ └─ Subtensor{i} [id BT] - │ │ │ │ │ ├─ Shape [id BU] - │ │ │ │ │ │ └─ Unbroadcast{0} [id BN] - │ │ │ │ │ │ └─ ··· - │ │ │ │ │ └─ 1 [id BV] + │ │ │ │ │ ├─ Shape [id BM] + │ │ │ │ │ │ └─ ··· + │ │ │ │ │ └─ 1 [id BU] │ │ │ │ ├─ Unbroadcast{0} [id BN] │ │ │ │ │ └─ ··· - │ │ │ │ └─ ScalarFromTensor [id BW] + │ │ │ │ └─ ScalarFromTensor [id BV] │ │ │ │ └─ Subtensor{i} [id BL] │ │ │ │ └─ ··· │ │ │ └─ *2- [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0) - │ │ └─ 1 [id BX] - │ └─ -1 [id BY] - └─ ExpandDims{axis=0} [id BZ] + │ │ └─ 1 [id BW] + │ └─ -1 [id BX] + └─ ExpandDims{axis=0} [id BY] └─ *1- [id Z] (inner_in_seqs-1) Scan{scan_fn, while_loop=False, inplace=none} [id BH] - → *0- [id CA] -> [id BI] (inner_in_sit_sot-0) - → *1- [id CB] -> [id BA] (inner_in_non_seqs-0) - ← Mul [id CC] (inner_out_sit_sot-0) - ├─ *0- [id CA] (inner_in_sit_sot-0) - └─ *1- [id CB] (inner_in_non_seqs-0) + → *0- [id BZ] -> [id BI] (inner_in_sit_sot-0) + → *1- [id CA] -> [id BA] (inner_in_non_seqs-0) + ← Mul [id CB] (inner_out_sit_sot-0) + ├─ *0- [id BZ] (inner_in_sit_sot-0) + └─ *1- [id CA] (inner_in_non_seqs-0) """ for truth, out in zip(expected_output.split("\n"), lines, strict=True): @@ -528,98 +524,97 @@ def test_debugprint_mitmot(): │ │ │ │ │ │ │ │ └─ 1.0 [id R] │ │ │ │ │ │ │ └─ 0 [id S] │ │ │ │ │ │ └─ Subtensor{i} [id T] - │ │ │ │ │ │ ├─ Shape [id U] - │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M] - │ │ │ │ │ │ │ └─ ··· - │ │ │ │ │ │ └─ 1 [id V] + │ │ │ │ │ │ ├─ Shape [id L] + │ │ │ │ │ │ │ └─ ··· + │ │ │ │ │ │ └─ 1 [id U] │ │ │ │ │ ├─ Unbroadcast{0} [id M] │ │ │ │ │ │ └─ ··· - │ │ │ │ │ └─ ScalarFromTensor [id W] + │ │ │ │ │ └─ ScalarFromTensor [id V] │ │ │ │ │ └─ Subtensor{i} [id K] │ │ │ │ │ └─ ··· │ │ │ │ └─ A [id P] (outer_in_non_seqs-0) - │ │ │ └─ 0 [id X] - │ │ └─ 1 [id Y] - │ ├─ Subtensor{:stop} [id Z] (outer_in_seqs-0) - │ │ ├─ Subtensor{::step} [id BA] - │ │ │ ├─ Subtensor{:stop} [id BB] + │ │ │ └─ 0 [id W] + │ │ └─ 1 [id X] + │ ├─ Subtensor{:stop} [id Y] (outer_in_seqs-0) + │ │ ├─ Subtensor{::step} [id Z] + │ │ │ ├─ Subtensor{:stop} [id BA] │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ │ └─ ··· - │ │ │ │ └─ -1 [id BC] - │ │ │ └─ -1 [id BD] - │ │ └─ ScalarFromTensor [id BE] + │ │ │ │ └─ -1 [id BB] + │ │ │ └─ -1 [id BC] + │ │ └─ ScalarFromTensor [id BD] │ │ └─ Sub [id C] │ │ └─ ··· - │ ├─ Subtensor{:stop} [id BF] (outer_in_seqs-1) - │ │ ├─ Subtensor{:stop} [id BG] - │ │ │ ├─ Subtensor{::step} [id BH] + │ ├─ Subtensor{:stop} 
[id BE] (outer_in_seqs-1) + │ │ ├─ Subtensor{:stop} [id BF] + │ │ │ ├─ Subtensor{::step} [id BG] │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ │ └─ ··· - │ │ │ │ └─ -1 [id BI] - │ │ │ └─ -1 [id BJ] - │ │ └─ ScalarFromTensor [id BK] + │ │ │ │ └─ -1 [id BH] + │ │ │ └─ -1 [id BI] + │ │ └─ ScalarFromTensor [id BJ] │ │ └─ Sub [id C] │ │ └─ ··· - │ ├─ Subtensor{::step} [id BL] (outer_in_mit_mot-0) - │ │ ├─ IncSubtensor{start:} [id BM] - │ │ │ ├─ Second [id BN] + │ ├─ Subtensor{::step} [id BK] (outer_in_mit_mot-0) + │ │ ├─ IncSubtensor{start:} [id BL] + │ │ │ ├─ Second [id BM] │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ │ └─ ··· - │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO] - │ │ │ │ └─ 0.0 [id BP] - │ │ │ ├─ IncSubtensor{i} [id BQ] - │ │ │ │ ├─ Second [id BR] - │ │ │ │ │ ├─ Subtensor{start:} [id BS] + │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BN] + │ │ │ │ └─ 0.0 [id BO] + │ │ │ ├─ IncSubtensor{i} [id BP] + │ │ │ │ ├─ Second [id BQ] + │ │ │ │ │ ├─ Subtensor{start:} [id BR] │ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ │ │ │ └─ ··· - │ │ │ │ │ │ └─ 1 [id BT] - │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU] - │ │ │ │ │ └─ 0.0 [id BV] - │ │ │ │ ├─ Second [id BW] - │ │ │ │ │ ├─ Subtensor{i} [id BX] - │ │ │ │ │ │ ├─ Subtensor{start:} [id BS] + │ │ │ │ │ │ └─ 1 [id BS] + │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BT] + │ │ │ │ │ └─ 0.0 [id BU] + │ │ │ │ ├─ Second [id BV] + │ │ │ │ │ ├─ Subtensor{i} [id BW] + │ │ │ │ │ │ ├─ Subtensor{start:} [id BR] │ │ │ │ │ │ │ └─ ··· - │ │ │ │ │ │ └─ -1 [id BY] - │ │ │ │ │ └─ ExpandDims{axis=0} [id BZ] - │ │ │ │ │ └─ Second [id CA] - │ │ │ │ │ ├─ Sum{axes=None} [id CB] - │ │ │ │ │ │ └─ Subtensor{i} [id BX] + │ │ │ │ │ │ └─ -1 [id BX] + │ │ │ │ │ └─ ExpandDims{axis=0} [id BY] + │ │ │ │ │ └─ Second [id BZ] + │ │ │ │ │ ├─ Sum{axes=None} [id CA] + │ │ │ │ │ │ └─ Subtensor{i} [id BW] │ │ │ │ │ │ └─ ··· - │ │ │ │ │ └─ 1.0 [id CC] - │ │ │ │ └─ -1 [id BY] - │ │ │ └─ 1 [id BT] - │ │ └─ -1 [id CD] - │ ├─ Alloc [id CE] (outer_in_sit_sot-0) - │ │ ├─ 0.0 [id CF] - │ │ ├─ Add [id CG] + │ │ │ │ │ └─ 1.0 [id CB] + │ │ │ │ └─ -1 [id BX] + │ │ │ └─ 1 [id BS] + │ │ └─ -1 [id CC] + │ ├─ Alloc [id CD] (outer_in_sit_sot-0) + │ │ ├─ 0.0 [id CE] + │ │ ├─ Add [id CF] │ │ │ ├─ Sub [id C] │ │ │ │ └─ ··· - │ │ │ └─ 1 [id CH] - │ │ └─ Subtensor{i} [id CI] - │ │ ├─ Shape [id CJ] + │ │ │ └─ 1 [id CG] + │ │ └─ Subtensor{i} [id CH] + │ │ ├─ Shape [id CI] │ │ │ └─ A [id P] - │ │ └─ 0 [id CK] + │ │ └─ 0 [id CJ] │ └─ A [id P] (outer_in_non_seqs-0) - └─ -1 [id CL] + └─ -1 [id CK] Inner graphs: Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B] - ← Add [id CM] (inner_out_mit_mot-0-0) - ├─ Mul [id CN] - │ ├─ *2- [id CO] -> [id BL] (inner_in_mit_mot-0-0) - │ └─ *5- [id CP] -> [id P] (inner_in_non_seqs-0) - └─ *3- [id CQ] -> [id BL] (inner_in_mit_mot-0-1) - ← Add [id CR] (inner_out_sit_sot-0) - ├─ Mul [id CS] - │ ├─ *2- [id CO] -> [id BL] (inner_in_mit_mot-0-0) - │ └─ *0- [id CT] -> [id Z] (inner_in_seqs-0) - └─ *4- [id CU] -> [id CE] (inner_in_sit_sot-0) + ← Add [id CL] (inner_out_mit_mot-0-0) + ├─ Mul [id CM] + │ ├─ *2- [id CN] -> [id BK] (inner_in_mit_mot-0-0) + │ └─ *5- [id CO] -> [id P] (inner_in_non_seqs-0) + └─ *3- [id CP] -> [id BK] (inner_in_mit_mot-0-1) + ← Add [id CQ] (inner_out_sit_sot-0) + ├─ Mul [id CR] + │ ├─ *2- [id CN] -> [id BK] (inner_in_mit_mot-0-0) + │ └─ *0- [id CS] -> [id Y] (inner_in_seqs-0) + └─ *4- [id CT] -> [id CD] 
(inner_in_sit_sot-0) Scan{scan_fn, while_loop=False, inplace=none} [id F] - ← Mul [id CV] (inner_out_sit_sot-0) - ├─ *0- [id CT] -> [id H] (inner_in_sit_sot-0) - └─ *1- [id CW] -> [id P] (inner_in_non_seqs-0) + ← Mul [id CU] (inner_out_sit_sot-0) + ├─ *0- [id CS] -> [id H] (inner_in_sit_sot-0) + └─ *1- [id CV] -> [id P] (inner_in_non_seqs-0) """ for truth, out in zip(expected_output.split("\n"), lines, strict=True): @@ -648,35 +643,37 @@ def no_shared_fn(n, x_tm1, M): # (i.e. from `Scan._fn`) out = pytensor.function([M], out, updates=updates, mode="FAST_RUN") - expected_output = """Scan{scan_fn, while_loop=False, inplace=all} [id A] 2 (outer_out_sit_sot-0) - ├─ 20000 [id B] (n_steps) - ├─ [ 0 ... 998 19999] [id C] (outer_in_seqs-0) - ├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0) - │ ├─ AllocEmpty{dtype='int64'} [id E] 0 - │ │ └─ 20000 [id B] - │ ├─ [0] [id F] - │ └─ 1 [id G] - └─ [id H] (outer_in_non_seqs-0) - - Inner graphs: - - Scan{scan_fn, while_loop=False, inplace=all} [id A] - ← Composite{switch(lt(0, i0), 1, 0)} [id I] (inner_out_sit_sot-0) - └─ Subtensor{i, j, k} [id J] - ├─ *2- [id K] -> [id H] (inner_in_non_seqs-0) - ├─ ScalarFromTensor [id L] - │ └─ *0- [id M] -> [id C] (inner_in_seqs-0) - ├─ ScalarFromTensor [id N] - │ └─ *1- [id O] -> [id D] (inner_in_sit_sot-0) - └─ 0 [id P] - - Composite{switch(lt(0, i0), 1, 0)} [id I] - ← Switch [id Q] 'o0' - ├─ LT [id R] - │ ├─ 0 [id S] - │ └─ i0 [id T] - ├─ 1 [id U] - └─ 0 [id S] + expected_output = """Subtensor{start:} [id A] 3 + ├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] 2 (outer_out_sit_sot-0) + │ ├─ 20000 [id C] (n_steps) + │ ├─ [ 0 ... 998 19999] [id D] (outer_in_seqs-0) + │ ├─ SetSubtensor{:stop} [id E] 1 (outer_in_sit_sot-0) + │ │ ├─ AllocEmpty{dtype='int64'} [id F] 0 + │ │ │ └─ 20001 [id G] + │ │ ├─ [0] [id H] + │ │ └─ 1 [id I] + │ └─ [id J] (outer_in_non_seqs-0) + └─ 1 [id I] + +Inner graphs: + +Scan{scan_fn, while_loop=False, inplace=all} [id B] + ← Composite{switch(lt(0, i0), 1, 0)} [id K] (inner_out_sit_sot-0) + └─ Subtensor{i, j, k} [id L] + ├─ *2- [id M] -> [id J] (inner_in_non_seqs-0) + ├─ ScalarFromTensor [id N] + │ └─ *0- [id O] -> [id D] (inner_in_seqs-0) + ├─ ScalarFromTensor [id P] + │ └─ *1- [id Q] -> [id E] (inner_in_sit_sot-0) + └─ 0 [id R] + +Composite{switch(lt(0, i0), 1, 0)} [id K] + ← Switch [id S] 'o0' + ├─ LT [id T] + │ ├─ 0 [id U] + │ └─ i0 [id V] + ├─ 1 [id W] + └─ 0 [id U] """ output_str = debugprint(out, file="str", print_op_info=True) diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py index fd9c43b129..e9a6d437ca 100644 --- a/tests/scan/test_rewriting.py +++ b/tests/scan/test_rewriting.py @@ -9,7 +9,7 @@ from pytensor.compile.mode import get_default_mode from pytensor.configdefaults import config from pytensor.gradient import grad, jacobian -from pytensor.graph.basic import equal_computations +from pytensor.graph.basic import Constant, equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.graph.replace import clone_replace from pytensor.scan.op import Scan @@ -742,7 +742,7 @@ def rnn_step1( utt.assert_allclose(f_opt_output, f_no_opt_output) def test_non_zero_init(self): - """Test the case where the initial value for the nitsot output is non-zero.""" + """Test the case where the initial value for the sitsot output is non-zero.""" input1 = tensor3() input2 = tensor3() @@ -759,8 +759,7 @@ def inner_fct(seq1, seq2, seq3, previous_output): init = pt.as_tensor_variable(np.random.normal(size=(3, 7))) - # Compile the function twice, once with the 
optimization and once - # without + # Compile the function twice, once with the optimization and once without opt_mode = mode.including("scan") h, _ = pytensor.scan( inner_fct, @@ -792,7 +791,7 @@ def inner_fct(seq1, seq2, seq3, previous_output): output_opt = f_opt(input1_value, input2_value, input3_value) output_no_opt = f_no_opt(input1_value, input2_value, input3_value) - utt.assert_allclose(output_opt, output_no_opt) + np.testing.assert_allclose(output_opt, output_no_opt) class TestScanMerge: @@ -1208,7 +1207,7 @@ def test_inplace3(self): class TestSaveMem: - mode = get_default_mode().including("scan_save_mem", "scan_save_mem") + mode = get_default_mode().including("scan_save_mem") def test_save_mem(self): rng = np.random.default_rng(utt.fetch_seed()) @@ -1295,11 +1294,27 @@ def f_rnn(u_t): [x1[:2], x2[4], x3[idx], x4[:idx], x5[-10], x6[-jdx], x7[:-jdx]], updates=updates, allow_input_downcast=True, - mode=self.mode, + mode=self.mode.excluding("scan_push_out_seq"), ) + # Check we actually have a Scan in the compiled function + [scan_node] = [ + node for node in f2.maker.fgraph.toposort() if isinstance(node.op, Scan) + ] + # get random initial values rng = np.random.default_rng(utt.fetch_seed()) - v_u = rng.uniform(-5.0, 5.0, size=(20,)) + v_u = rng.uniform(-5.0, 5.0, size=(20,)).astype(u.type.dtype) + + # Check the number of steps is actually reduced from 20 + n_steps = scan_node.inputs[0] + n_steps_fn = pytensor.function( + [u, idx, jdx], n_steps, accept_inplace=True, on_unused_input="ignore" + ) + assert n_steps_fn(u=v_u, idx=3, jdx=15) == 11 # x5[const=-10] requires 11 steps + assert n_steps_fn(u=v_u, idx=3, jdx=3) == 18 # x6[jdx=-3] requires 18 steps + assert n_steps_fn(u=v_u, idx=16, jdx=15) == 17 # x3[idx=16] requires 17 steps + assert n_steps_fn(u=v_u, idx=-5, jdx=15) == 16 # x3[idx=-5] requires 16 steps + assert n_steps_fn(u=v_u, idx=19, jdx=15) == 20 # x3[idx=19] requires 20 steps # compute the output in numpy tx1, tx2, tx3, tx4, tx5, tx6, tx7 = f2(v_u, 3, 15) @@ -1312,6 +1327,49 @@ def f_rnn(u_t): utt.assert_allclose(tx6, v_u[-15] + 6.0) utt.assert_allclose(tx7, v_u[:-15] + 7.0) + def test_save_mem_reduced_number_of_steps_constant(self): + x0 = pt.scalar("x0") + xs, _ = scan( + lambda xtm1: xtm1 + 1, + outputs_info=[x0], + n_steps=10, + ) + + fn = function([x0], xs[:5], mode=self.mode) + [scan_node] = [ + node for node in fn.maker.fgraph.toposort() if isinstance(node.op, Scan) + ] + n_steps = scan_node.inputs[0] + assert isinstance(n_steps, Constant) and n_steps.data == 5 + + np.testing.assert_allclose(fn(0), np.arange(1, 11)[:5]) + + def test_save_mem_cannot_reduce_constant_number_of_steps(self): + x0 = pt.scalar("x0") + [xs, ys], _ = scan( + lambda xtm1, ytm1: (xtm1 + 1, ytm1 - 1), + outputs_info=[x0, x0], + n_steps=10, + ) + + # Because of ys[-1] we need all the steps! 
+ fn = function([x0], [xs[:5], ys[-1]], mode=self.mode) + [scan_node] = [ + node for node in fn.maker.fgraph.toposort() if isinstance(node.op, Scan) + ] + n_steps = scan_node.inputs[0] + assert isinstance(n_steps, Constant) and n_steps.data == 10 + + res_x, res_y = fn(0) + np.testing.assert_allclose( + res_x, + np.arange(1, 11)[:5], + ) + np.testing.assert_allclose( + res_y, + -np.arange(1, 11)[-1], + ) + def test_save_mem_store_steps(self): def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1): return ( diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index ebe07f4947..78ec97eff3 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -16,6 +16,7 @@ from pytensor.configdefaults import config from pytensor.gradient import grad from pytensor.graph import Constant +from pytensor.graph.basic import equal_computations from pytensor.graph.op import get_test_value from pytensor.graph.rewriting.utils import is_same_graph from pytensor.printing import pprint @@ -23,7 +24,7 @@ from pytensor.tensor import as_tensor, get_vector_length, vectorize from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle -from pytensor.tensor.math import exp, isinf +from pytensor.tensor.math import exp, isinf, lt, switch from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.shape import specify_shape from pytensor.tensor.subtensor import ( @@ -136,30 +137,41 @@ def test_unsupported_inputs(self, idx): def test_scalar_constant(self): a = as_scalar(0) length = lscalar() - res = get_canonical_form_slice(a, length) - assert isinstance(res[0].owner.op, ptb.ScalarFromTensor) - assert res[1] == 1 + res, direction = get_canonical_form_slice(a, length) + assert res == 0 + assert direction == 1 + + b = as_scalar(-1) + res, direction = get_canonical_form_slice(b, length) + assert equal_computations([res], [as_tensor(-1) + length]) + assert direction == 1 def test_tensor_constant(self): a = as_tensor(0) length = lscalar() - res = get_canonical_form_slice(a, length) - assert isinstance(res[0].owner.op, ptb.ScalarFromTensor) - assert res[1] == 1 + res, direction = get_canonical_form_slice(a, length) + assert equal_computations([res], [a]) + assert direction == 1 + + b = as_tensor(-1) + res, direction = get_canonical_form_slice(b, length) + assert equal_computations([res], [b + length]) + assert direction == 1 def test_symbolic_scalar(self): a = int16() length = lscalar() - res = get_canonical_form_slice(a, length) - assert res[0].owner.op, ptb.switch - assert res[1] == 1 + res, direction = get_canonical_form_slice(a, length) + a_t = as_tensor(a) + assert equal_computations([res], [switch(lt(a_t, 0), a_t + length, a_t)]) + assert direction == 1 def test_symbolic_tensor(self): a = lscalar() length = lscalar() - res = get_canonical_form_slice(a, length) - assert isinstance(res[0].owner.op, ptb.ScalarFromTensor) - assert res[1] == 1 + res, direction = get_canonical_form_slice(a, length) + assert equal_computations([res], [switch(lt(a, 0), a + length, a)]) + assert direction == 1 @pytest.mark.parametrize("int_fn", [int, np.int64, as_tensor, as_scalar]) def test_all_integer(self, int_fn):
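Reviewer note (illustration only, not part of the patch): the post-processing statement generated for the Numba backend above rotates the circular output buffer with two slices and a `concatenate`, i.e. a left-rotation by `shift`; the new guard skips size-1 buffers and the already-aligned case (`shift == 0`). A NumPy-only sketch of that idea, under my own assumptions about where the oldest entry ends up:

import numpy as np

def rotate_left(buf, shift):
    # Same trick as the generated code: equivalent to np.roll(buf, -shift).
    if shift > 0:
        left, right = buf[:shift], buf[shift:]
        buf = np.concatenate((right, left))
    return buf

# Example: steps 1..5 written cyclically into a 3-slot buffer leave it as [4, 5, 3];
# the oldest kept step (3) sits at index 2, so rotating left by 2 restores order.
print(rotate_left(np.array([4, 5, 3]), shift=2))  # -> [3 4 5]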