diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9677615206..48ada81faf 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -193,7 +193,7 @@ jobs:
           else micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock;
           fi
-          if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
+          micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"
           if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
           if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
           pip install pytest-sphinx
diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index ffa27e5d5a..adf9732f5e 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -63,9 +63,8 @@ def register_linker(name, linker):
 # If a string is passed as the optimizer argument in the constructor
 # for Mode, it will be used as the key to retrieve the real optimizer
 # in this dictionary
-exclude = []
-if not config.cxx:
-    exclude = ["cxx_only"]
+
+exclude = ["cxx_only", "BlasOpt"]
 OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
 # Even if multiple merge optimizer call will be there, this shouldn't
 # impact performance.
@@ -346,6 +345,11 @@ def __setstate__(self, state):
             optimizer = predefined_optimizers[optimizer]
         if isinstance(optimizer, RewriteDatabaseQuery):
             self.provided_optimizer = optimizer
+
+        # Force numba-required rewrites if using NumbaLinker
+        if isinstance(linker, NumbaLinker):
+            optimizer = optimizer.including("numba")
+
         self._optimizer = optimizer
         self.call_time = 0
         self.fn_time = 0
@@ -443,16 +447,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 # string as the key
 # Use VM_linker to allow lazy evaluation by default.
 FAST_COMPILE = Mode(
-    VMLinker(use_cloop=False, c_thunks=False),
-    RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
+    NumbaLinker(),
+    # TODO: Fast_compile should just use python code, CHANGE ME!
+    RewriteDatabaseQuery(
+        include=["fast_compile", "numba"],
+        exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
+    ),
+)
+FAST_RUN = Mode(
+    NumbaLinker(),
+    RewriteDatabaseQuery(
+        include=["fast_run", "numba"],
+        exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
+    ),
 )
-if config.cxx:
-    FAST_RUN = Mode("cvm", "fast_run")
-else:
-    FAST_RUN = Mode(
-        "vm",
-        RewriteDatabaseQuery(include=["fast_run", "py_only"]),
-    )

 NUMBA = Mode(
     NumbaLinker(),
@@ -565,6 +573,7 @@ def register_mode(name, mode):
     Add a `Mode` which can be referred to by `name` in `function`.

     """
+    # TODO: Remove me
     if name in predefined_modes:
         raise ValueError(f"Mode name already taken: {name}")
     predefined_modes[name] = mode
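
Note: with the mode changes above, FAST_RUN (and therefore PyTensor's default mode) JIT-compiles through the Numba backend instead of the C VM. A minimal sketch of what that means for a user, assuming this patch is applied; the snippet is illustrative and not part of the diff:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# FAST_RUN is now Mode(NumbaLinker(), ...), so this compiles via numba
f = pytensor.function([x], 2 * x)
print(f([1.0, 2.0]))  # [2. 4.]
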
""" + # TODO: Remove me if name in predefined_modes: raise ValueError(f"Mode name already taken: {name}") predefined_modes[name] = mode diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index ca3c44bf6d..9c8fecab33 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -370,11 +370,21 @@ def add_compile_configvars(): if rc == 0 and config.cxx != "": # Keep the default linker the same as the one for the mode FAST_RUN - linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"] + linker_options = [ + "cvm", + "c|py", + "py", + "c", + "c|py_nogc", + "vm", + "vm_nogc", + "cvm_nogc", + "jax", + ] else: # g++ is not present or the user disabled it, # linker should default to python only. - linker_options = ["py", "vm_nogc"] + linker_options = ["py", "vm", "vm_nogc", "jax"] if type(config).cxx.is_default: # If the user provided an empty value for cxx, do not warn. _logger.warning( @@ -388,7 +398,7 @@ def add_compile_configvars(): "linker", "Default linker used if the pytensor flags mode is Mode", # Not mutable because the default mode is cached after the first use. - EnumStr("cvm", linker_options, mutable=False), + EnumStr("numba", linker_options, mutable=False), in_c_key=False, ) diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py index 06370b4514..5e1a87e31e 100644 --- a/pytensor/link/jax/linker.py +++ b/pytensor/link/jax/linker.py @@ -87,3 +87,6 @@ def create_thunk_inputs(self, storage_map): thunk_inputs.append(sinput) return thunk_inputs + + def __repr__(self): + return "JAXLinker()" diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 2f3cac6ea6..f961524fa0 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -40,6 +40,7 @@ from pytensor.tensor.slinalg import Solve from pytensor.tensor.type import TensorType from pytensor.tensor.type_other import MakeSlice, NoneConst +from pytensor.typed_list import TypedListType def global_numba_func(func): @@ -135,6 +136,8 @@ def get_numba_type( return CSCMatrixType(numba_dtype) raise NotImplementedError() + elif isinstance(pytensor_type, TypedListType): + return numba.types.List(get_numba_type(pytensor_type.ttype)) else: raise NotImplementedError(f"Numba type not implemented for {pytensor_type}") @@ -481,11 +484,11 @@ def numba_funcify_SpecifyShape(op, node, **kwargs): shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] func_conditions = [ - f"assert x.shape[{i}] == {shape_input_names}" - for i, (shape_input, shape_input_names) in enumerate( + f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'" + for i, (node_dim_input, eval_dim_name) in enumerate( zip(shape_inputs, shape_input_names, strict=True) ) - if shape_input is not NoneConst + if node_dim_input is not NoneConst ] func = dedent( diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 9fd81dadcf..c81cc89830 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -411,7 +411,15 @@ def numba_funcify_CAReduce(op, node, **kwargs): @numba_funcify.register(DimShuffle) -def numba_funcify_DimShuffle(op, node, **kwargs): +def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs): + if op.is_left_expand_dims and op.new_order.count("x") == 1: + # Most common case, numba compiles it more quickly + @numba_njit + def left_expand_dims(x): 
diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py
index e9b637b00f..b9121899ca 100644
--- a/pytensor/link/numba/dispatch/scalar.py
+++ b/pytensor/link/numba/dispatch/scalar.py
@@ -23,6 +23,7 @@
     Composite,
     Identity,
     Mul,
+    Pow,
     Reciprocal,
     ScalarOp,
     Second,
@@ -154,6 +155,21 @@ def numba_funcify_Switch(op, node, **kwargs):
     return numba_basic.global_numba_func(switch)


+@numba_funcify.register(Pow)
+def numba_funcify_Pow(op, node, **kwargs):
+    pow_dtype = node.inputs[1].type.dtype
+
+    def pow(x, y):
+        return x**y
+
+    # Work-around https://github.com/numba/numba/issues/9554
+    # fast-math causes kernel crashes
+    patch_kwargs = {}
+    if pow_dtype.startswith("int"):
+        patch_kwargs["fastmath"] = False
+    return numba_basic.numba_njit(**patch_kwargs)(pow)
+
+
 def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
     """Create a Numba-compatible N-ary function from a binary function."""
     unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_")
@@ -172,18 +188,64 @@ def {binary_op_name}({input_signature}):

 @numba_funcify.register(Add)
 def numba_funcify_Add(op, node, **kwargs):
+    match len(node.inputs):
+        case 2:
+
+            def add(i0, i1):
+                return i0 + i1
+        case 3:
+
+            def add(i0, i1, i2):
+                return i0 + i1 + i2
+        case 4:
+
+            def add(i0, i1, i2, i3):
+                return i0 + i1 + i2 + i3
+        case 5:
+
+            def add(i0, i1, i2, i3, i4):
+                return i0 + i1 + i2 + i3 + i4
+        case _:
+            add = None
+
+    if add is not None:
+        return numba_basic.numba_njit(add)
+
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

-    return numba_basic.numba_njit(signature)(nary_add_fn)
+    return numba_basic.numba_njit(signature, cache=False)(nary_add_fn)


 @numba_funcify.register(Mul)
 def numba_funcify_Mul(op, node, **kwargs):
+    match len(node.inputs):
+        case 2:
+
+            def mul(i0, i1):
+                return i0 * i1
+        case 3:
+
+            def mul(i0, i1, i2):
+                return i0 * i1 * i2
+        case 4:
+
+            def mul(i0, i1, i2, i3):
+                return i0 * i1 * i2 * i3
+        case 5:
+
+            def mul(i0, i1, i2, i3, i4):
+                return i0 * i1 * i2 * i3 * i4
+        case _:
+            mul = None
+
+    if mul is not None:
+        return numba_basic.numba_njit(mul)
+
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")

-    return numba_basic.numba_njit(signature)(nary_add_fn)
+    return numba_basic.numba_njit(signature, cache=False)(nary_add_fn)


 @numba_funcify.register(Cast)
@@ -233,7 +295,7 @@ def numba_funcify_Composite(op, node, **kwargs):

     _ = kwargs.pop("storage_map", None)

-    composite_fn = numba_basic.numba_njit(signature)(
+    composite_fn = numba_basic.numba_njit(signature, cache=False)(
         numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
     )
     return composite_fn
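
Note: per the workaround above, fastmath is disabled only when the exponent has an integer dtype. A hedged sketch of the resulting compilation (the helper name is hypothetical; see the referenced numba issue for the crash it avoids):

import numba

@numba.njit(fastmath=False)  # fastmath=True can crash for integer x ** y
def int_pow(x, y):
    return x ** y

assert int_pow(2, 10) == 1024
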
diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py
index ee9e183d16..c35d49c485 100644
--- a/pytensor/link/numba/dispatch/subtensor.py
+++ b/pytensor/link/numba/dispatch/subtensor.py
@@ -23,6 +23,60 @@
 def numba_funcify_default_subtensor(op, node, **kwargs):
     """Create a Python function that assembles and uses an index on an array."""

+    if isinstance(op, Subtensor) and len(op.idx_list) == 1:
+        # Hard code indices along first dimension to allow caching
+        [idx] = op.idx_list
+
+        if isinstance(idx, slice):
+            slice_info = (
+                idx.start is not None,
+                idx.stop is not None,
+                idx.step is not None,
+            )
+            match slice_info:
+                case (False, False, False):
+
+                    def subtensor(x):
+                        return x
+
+                case (True, False, False):
+
+                    def subtensor(x, start):
+                        return x[start:]
+                case (False, True, False):
+
+                    def subtensor(x, stop):
+                        return x[:stop]
+                case (False, False, True):
+
+                    def subtensor(x, step):
+                        return x[::step]
+
+                case (True, True, False):
+
+                    def subtensor(x, start, stop):
+                        return x[start:stop]
+                case (True, False, True):
+
+                    def subtensor(x, start, step):
+                        return x[start::step]
+                case (False, True, True):
+
+                    def subtensor(x, stop, step):
+                        return x[:stop:step]
+
+                case (True, True, True):
+
+                    def subtensor(x, start, stop, step):
+                        return x[start:stop:step]
+
+        else:
+
+            def subtensor(x, i):
+                return np.asarray(x[i])
+
+        return numba_njit(subtensor)
+
     unique_names = unique_name_generator(
         ["subtensor", "incsubtensor", "z"], suffix_sep="_"
     )
@@ -100,7 +154,7 @@ def {function_name}({", ".join(input_names)}):
         function_name=function_name,
         global_env=globals() | {"np": np},
     )
-    return numba_njit(func, boundscheck=True)
+    return numba_njit(func, boundscheck=True, cache=False)


 @numba_funcify.register(AdvancedSubtensor)
@@ -294,7 +348,9 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
         if broadcast:

             @numba_njit(boundscheck=True)
-            def advancedincsubtensor1_inplace(x, val, idxs):
+            def advanced_incsubtensor1(x, val, idxs):
+                out = x if inplace else x.copy()
+
                 if val.ndim == x.ndim:
                     core_val = val[0]
                 elif val.ndim == 0:
@@ -304,24 +360,28 @@ def advancedincsubtensor1_inplace(x, val, idxs):
                     core_val = val

                 for idx in idxs:
-                    x[idx] = core_val
-                return x
+                    out[idx] = core_val
+                return out

         else:

             @numba_njit(boundscheck=True)
-            def advancedincsubtensor1_inplace(x, vals, idxs):
+            def advanced_incsubtensor1(x, vals, idxs):
+                out = x if inplace else x.copy()
+
                 if not len(idxs) == len(vals):
                     raise ValueError("The number of indices and values must match.")
                 # no strict argument because incompatible with numba
                 for idx, val in zip(idxs, vals):  # noqa: B905
-                    x[idx] = val
-                return x
+                    out[idx] = val
+                return out
     else:
         if broadcast:

             @numba_njit(boundscheck=True)
-            def advancedincsubtensor1_inplace(x, val, idxs):
+            def advanced_incsubtensor1(x, val, idxs):
+                out = x if inplace else x.copy()
+
                 if val.ndim == x.ndim:
                     core_val = val[0]
                 elif val.ndim == 0:
@@ -331,29 +391,21 @@
                     core_val = val

                 for idx in idxs:
-                    x[idx] += core_val
-                return x
+                    out[idx] += core_val
+                return out

         else:

             @numba_njit(boundscheck=True)
-            def advancedincsubtensor1_inplace(x, vals, idxs):
+            def advanced_incsubtensor1(x, vals, idxs):
+                out = x if inplace else x.copy()
+
                 if not len(idxs) == len(vals):
                     raise ValueError("The number of indices and values must match.")
                 # no strict argument because unsupported by numba
                 # TODO: this doesn't come up in tests
                 for idx, val in zip(idxs, vals):  # noqa: B905
-                    x[idx] += val
-                return x
-
-    if inplace:
-        return advancedincsubtensor1_inplace
-
-    else:
-
-        @numba_njit
-        def advancedincsubtensor1(x, vals, idxs):
-            x = x.copy()
-            return advancedincsubtensor1_inplace(x, vals, idxs)
+                    out[idx] += val
+                return out

-        return advancedincsubtensor1
+    return advanced_incsubtensor1
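
Note: the hard-coded slice kernels above exist because numba's on-disk cache only works for functions whose source lives in a real file; the exec-generated fallback below them must therefore pass cache=False. A minimal illustration of the cacheable shape (kernel name hypothetical):

import numba

@numba.njit(cache=True)  # cacheable: statically defined in a source file
def subtensor_start_stop(x, start, stop):
    return x[start:stop]
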
diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py
index 8f5972c058..30e89c4256 100644
--- a/pytensor/link/numba/dispatch/tensor_basic.py
+++ b/pytensor/link/numba/dispatch/tensor_basic.py
@@ -68,7 +68,7 @@ def numba_funcify_Alloc(op, node, **kwargs):
     shape_var_item_names = [f"{name}_item" for name in shape_var_names]
     shapes_to_items_src = indent(
         "\n".join(
-            f"{item_name} = to_scalar({shape_name})"
+            f"{item_name} = {shape_name}.item()"
             for item_name, shape_name in zip(
                 shape_var_item_names, shape_var_names, strict=True
             )
@@ -86,7 +86,7 @@

     alloc_def_src = f"""
 def alloc(val, {", ".join(shape_var_names)}):
-    val_np = np.asarray(val)
+    val_np = val
 {shapes_to_items_src}
     scalar_shape = {create_tuple_string(shape_var_item_names)}
 {check_runtime_broadcast_src}
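
Note: Alloc's generated source now assumes `val` is already an ndarray and unpacks the 0-d shape inputs with `.item()`. A rough sketch of what the template expands to for two dimensions (names hypothetical):

import numba
import numpy as np

@numba.njit
def alloc_2d(val, shape_0, shape_1):
    # 0-d integer arrays are unpacked to scalars to build the shape tuple
    return np.full((shape_0.item(), shape_1.item()), val)

assert alloc_2d(0.5, np.array(2), np.array(3)).shape == (2, 3)
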
diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py
index 74870e29bd..6ad6121719 100644
--- a/pytensor/link/numba/dispatch/vectorize_codegen.py
+++ b/pytensor/link/numba/dispatch/vectorize_codegen.py
@@ -35,6 +35,97 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on):
         on[...] = ton

     """
+    # Hardcode some cases for numba caching
+    match (nin, nout):
+        case (1, 1):
+
+            def func(i0, o0):
+                t0 = core_op_fn(i0)
+                o0[...] = t0
+        case (1, 2):
+
+            def func(i0, o0, o1):
+                t0, t1 = core_op_fn(i0)
+                o0[...] = t0
+                o1[...] = t1
+        case (1, 3):
+
+            def func(i0, o0, o1, o2):
+                t0, t1, t2 = core_op_fn(i0)
+                o0[...] = t0
+                o1[...] = t1
+                o2[...] = t2
+
+        case (2, 1):
+
+            def func(i0, i1, o0):
+                t0 = core_op_fn(i0, i1)
+                o0[...] = t0
+        case (2, 2):
+
+            def func(i0, i1, o0, o1):
+                t0, t1 = core_op_fn(i0, i1)
+                o0[...] = t0
+                o1[...] = t1
+        case (2, 3):
+
+            def func(i0, i1, o0, o1, o2):
+                t0, t1, t2 = core_op_fn(i0, i1)
+                o0[...] = t0
+                o1[...] = t1
+                o2[...] = t2
+
+        case (3, 1):
+
+            def func(i0, i1, i2, o0):
+                t0 = core_op_fn(i0, i1, i2)
+                o0[...] = t0
+
+        case (3, 2):
+
+            def func(i0, i1, i2, o0, o1):
+                t0, t1 = core_op_fn(i0, i1, i2)
+                o0[...] = t0
+                o1[...] = t1
+        case (3, 3):
+
+            def func(i0, i1, i2, o0, o1, o2):
+                t0, t1, t2 = core_op_fn(i0, i1, i2)
+                o0[...] = t0
+                o1[...] = t1
+                o2[...] = t2
+
+        case (4, 1):
+
+            def func(i0, i1, i2, i3, o0):
+                t0 = core_op_fn(i0, i1, i2, i3)
+                o0[...] = t0
+
+        case (4, 2):
+
+            def func(i0, i1, i2, i3, o0, o1):
+                t0, t1 = core_op_fn(i0, i1, i2, i3)
+                o0[...] = t0
+                o1[...] = t1
+
+        case (5, 1):
+
+            def func(i0, i1, i2, i3, i4, o0):
+                t0 = core_op_fn(i0, i1, i2, i3, i4)
+                o0[...] = t0
+
+        case (5, 2):
+
+            def func(i0, i1, i2, i3, i4, o0, o1):
+                t0, t1 = core_op_fn(i0, i1, i2, i3, i4)
+                o0[...] = t0
+                o1[...] = t1
+        case _:
+            func = None
+
+    if func is not None:
+        return numba_basic.numba_njit(func)
+
     inputs = [f"i{i}" for i in range(nin)]
     outputs = [f"o{i}" for i in range(nout)]
     inner_outputs = [f"t{output}" for output in outputs]
@@ -55,7 +146,7 @@ def store_core_outputs({inp_signature}, {out_signature}):
     func = compile_function_src(
         func_src, "store_core_outputs", {**globals(), **global_env}
     )
-    return cast(Callable, numba_basic.numba_njit(func))
+    return cast(Callable, numba_basic.numba_njit(func, cache=False))


 _jit_options = {
diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py
index 553c5ef217..4257cac07b 100644
--- a/pytensor/link/numba/linker.py
+++ b/pytensor/link/numba/linker.py
@@ -12,7 +12,10 @@ def fgraph_convert(self, fgraph, **kwargs):
     def jit_compile(self, fn):
         from pytensor.link.numba.dispatch.basic import numba_njit

-        jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False)
+        # NUMBA can't cache our dynamically generated funcified_fgraph
+        jitted_fn = numba_njit(
+            fn, no_cpython_wrapper=False, no_cfunc_wrapper=False, cache=False
+        )
         return jitted_fn

     def create_thunk_inputs(self, storage_map):
@@ -35,3 +38,6 @@
             thunk_inputs.append(sinput)

         return thunk_inputs
+
+    def __repr__(self):
+        return "NumbaLinker()"
diff --git a/tests/tensor/conv/test_abstract_conv.py b/tests/tensor/conv/test_abstract_conv.py
index 23ba23e1e9..814f1eb80b 100644
--- a/tests/tensor/conv/test_abstract_conv.py
+++ b/tests/tensor/conv/test_abstract_conv.py
@@ -948,9 +948,9 @@ def run_gradinput(
     )


-@pytest.mark.skipif(
-    config.cxx == "",
-    reason="SciPy and cxx needed",
+@pytest.mark.skipif(config.cxx == "", reason="cxx needed")
+@pytest.mark.xfail(
+    reason="Involves Ops with no Python implementation for numba to use as fallback"
 )
 class TestAbstractConvNoOptim(BaseTestConv2d):
     @classmethod
@@ -1884,9 +1884,9 @@ def test_conv2d_grad_wrt_weights(self):
     )


-@pytest.mark.skipif(
-    config.cxx == "",
-    reason="SciPy and cxx needed",
+@pytest.mark.skipif(config.cxx == "", reason="cxx needed")
+@pytest.mark.xfail(
+    reason="Involves Ops with no Python implementation for numba to use as fallback"
 )
 class TestGroupedConvNoOptim:
     conv = abstract_conv.AbstractConv2d
@@ -2096,9 +2096,9 @@ def conv_gradinputs(filters_val, output_val):
     utt.verify_grad(conv_gradinputs, [kern, top], mode=self.mode, eps=1)


-@pytest.mark.skipif(
-    config.cxx == "",
-    reason="SciPy and cxx needed",
+@pytest.mark.skipif(config.cxx == "", reason="cxx needed")
+@pytest.mark.xfail(
+    reason="Involves Ops with no Python implementation for numba to use as fallback"
 )
 class TestGroupedConv3dNoOptim(TestGroupedConvNoOptim):
     conv = abstract_conv.AbstractConv3d
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index 1730ae46ac..6f04b61506 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -1,4 +1,5 @@
 import copy
+import re

 import numpy as np
 import pytest
@@ -306,7 +307,9 @@ def test_inconsistent_shared(self, shape_unsafe):
             # Error raised by Alloc Op
             with pytest.raises(
                 ValueError,
-                match=r"could not broadcast input array from shape \(3,7\) into shape \(6,7\)",
+                match=re.escape(
+                    "cannot assign slice of shape (3, 7) from input of shape (6, 7)"
+                ),
             ):
                 f()
@@ -1203,6 +1206,7 @@ def test_sum_bool_upcast(self):
             f(5)


+@pytest.mark.xfail(reason="Numba does not support float16")
 class TestLocalOptAllocF16(TestLocalOptAlloc):
     dtype = "float16"
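
Note: the switch to re.escape above is needed because pytest.raises(match=...) treats the pattern as a regular expression, so the parentheses in the new error message would otherwise be parsed as regex groups. A self-contained illustration:

import re

import pytest

def test_error_message_with_parens():
    with pytest.raises(ValueError, match=re.escape("shape (3, 7)")):
        raise ValueError("cannot assign slice of shape (3, 7) from input of shape (6, 7)")
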
diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py
index dee0023efd..dfe9f2630f 100644
--- a/tests/tensor/test_basic.py
+++ b/tests/tensor/test_basic.py
@@ -150,7 +150,10 @@
 )


-pytestmark = pytest.mark.filterwarnings("error")
+pytestmark = pytest.mark.filterwarnings(
+    "error",
+    "ignore:Numba will use object mode:UserWarning",
+)

 if config.mode == "FAST_COMPILE":
     mode_opt = "FAST_RUN"
@@ -758,41 +761,43 @@ def check_allocs_in_fgraph(fgraph, n):
     def setup_method(self):
         self.rng = np.random.default_rng(seed=utt.fetch_seed())

-    def test_alloc_constant_folding(self):
+    @pytest.mark.parametrize(
+        "subtensor_fn, expected_grad_n_alloc",
+        [
+            # IncSubtensor1
+            (lambda x: x[:60], 1),
+            # AdvancedIncSubtensor1
+            (lambda x: x[np.arange(60)], 1),
+            # AdvancedIncSubtensor
+            (lambda x: x[np.arange(50), np.arange(50)], 1),
+        ],
+    )
+    def test_alloc_constant_folding(self, subtensor_fn, expected_grad_n_alloc):
         test_params = np.asarray(self.rng.standard_normal(50 * 60), self.dtype)

         some_vector = vector("some_vector", dtype=self.dtype)
         some_matrix = some_vector.reshape((60, 50))
         variables = self.shared(np.ones((50,), dtype=self.dtype))
-        idx = constant(np.arange(50))

-        for alloc_, (subtensor, n_alloc) in zip(
-            self.allocs,
-            [
-                # IncSubtensor1
-                (some_matrix[:60], 2),
-                # AdvancedIncSubtensor1
-                (some_matrix[arange(60)], 2),
-                # AdvancedIncSubtensor
-                (some_matrix[idx, idx], 1),
-            ],
-            strict=True,
-        ):
-            derp = pt_sum(dense_dot(subtensor, variables))
-
-            fobj = pytensor.function([some_vector], derp, mode=self.mode)
-            grad_derp = pytensor.grad(derp, some_vector)
-            fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
+        subtensor = subtensor_fn(some_matrix)

-            topo_obj = fobj.maker.fgraph.toposort()
-            assert sum(isinstance(node.op, type(alloc_)) for node in topo_obj) == 0
+        derp = pt_sum(dense_dot(subtensor, variables))
+        fobj = pytensor.function([some_vector], derp, mode=self.mode)
+        assert (
+            sum(isinstance(node.op, Alloc) for node in fobj.maker.fgraph.apply_nodes)
+            == 0
+        )
+        # TODO: Assert something about the value if we bothered to call it?
+        fobj(test_params)

-            topo_grad = fgrad.maker.fgraph.toposort()
-            assert (
-                sum(isinstance(node.op, type(alloc_)) for node in topo_grad) == n_alloc
-            ), (alloc_, subtensor, n_alloc, topo_grad)
-            fobj(test_params)
-            fgrad(test_params)
+        grad_derp = pytensor.grad(derp, some_vector)
+        fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
+        assert (
+            sum(isinstance(node.op, Alloc) for node in fgrad.maker.fgraph.apply_nodes)
+            == expected_grad_n_alloc
+        )
+        # TODO: Assert something about the value if we bothered to call it?
+        fgrad(test_params)

     def test_alloc_output(self):
         val = constant(self.rng.standard_normal((1, 1)), dtype=self.dtype)
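
Note on the module-level filter added above: pytest gives the more specific "ignore" entry precedence over the blanket "error" entry, so numba's object-mode fallback warning is silenced while every other warning is still promoted to an error. A small sketch of that behavior (test names hypothetical):

import warnings

import pytest

pytestmark = pytest.mark.filterwarnings(
    "error",
    "ignore:Numba will use object mode:UserWarning",
)

def test_object_mode_warning_is_ignored():
    warnings.warn("Numba will use object mode", UserWarning)  # filtered out

def test_other_warnings_become_errors():
    with pytest.raises(UserWarning):
        warnings.warn("anything else", UserWarning)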