diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 85e6694a95..fcfe678215 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -13,7 +13,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - pip -- pytensor>=2.25.1,<2.26 +- pytensor>=2.26.2,<2.27 - python-graphviz - networkx - scipy>=1.4.1 diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index 86097c5ab3..b937978375 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -11,7 +11,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - pip -- pytensor>=2.25.1,<2.26 +- pytensor>=2.26.2,<2.27 - python-graphviz - rich>=13.7.1 - scipy>=1.4.1 diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml index 97d25dd5b8..0c0b0f91fd 100644 --- a/conda-envs/environment-jax.yml +++ b/conda-envs/environment-jax.yml @@ -20,7 +20,7 @@ dependencies: - numpyro>=0.8.0 - pandas>=0.24.0 - pip -- pytensor>=2.25.1,<2.26 +- pytensor>=2.26.2,<2.27 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 58cde0d327..9ab2eac735 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -16,7 +16,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - pip -- pytensor>=2.25.1,<2.26 +- pytensor>=2.26.2,<2.27 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 6d785e2cac..309f7c4fb4 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -13,7 +13,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - pip -- pytensor>=2.25.1,<2.26 +- pytensor>=2.26.2,<2.27 - python-graphviz - networkx - rich>=13.7.1 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index fd17c31711..a3952c5dd9 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -16,7 +16,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - pip -- pytensor>=2.25.1,<2.26 +- pytensor>=2.26.2,<2.27 - python-graphviz - networkx - rich>=13.7.1 diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index 5503ce32b7..abb5df2ab5 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -286,9 +286,9 @@ def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None: base_var = node.inputs[0] - measurable_dimshuffle = MeasurableDimShuffle(node.op.input_broadcastable, node.op.new_order)( - base_var - ) + measurable_dimshuffle = MeasurableDimShuffle( + input_ndim=node.op.input_ndim, new_order=node.op.new_order + )(base_var) assert isinstance(measurable_dimshuffle, TensorVariable) return [measurable_dimshuffle] diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 213831c9f1..d0360e0131 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -45,6 +45,7 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.type import RandomType from pytensor.tensor.random.var import RandomGeneratorSharedVariable +from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding from pytensor.tensor.rewriting.shape import ShapeFeature from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1 @@ -1057,7 +1058,7 @@ def compile_pymc( def constant_fold( xs: Sequence[TensorVariable], raise_not_constant: bool = True -) -> tuple[np.ndarray, ...]: +) -> tuple[np.ndarray | Variable, ...]: """Use constant folding to get constant values of a graph. Parameters @@ -1072,8 +1073,12 @@ def constant_fold( """ fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], copy_inputs=False, clone=True) - # By default, rewrite_graph includes canonicalize which includes constant-folding as the final rewrite - folded_xs = rewrite_graph(fg).outputs + # The default rewrite_graph includes a constand_folding that is not always applied. + # We use an unconditional constant_folding as the last pass to ensure a thorough constant folding. + rewrite_graph(fg) + topo_unconditional_constant_folding.apply(fg) + + folded_xs = fg.outputs if raise_not_constant and not all(isinstance(folded_x, Constant) for folded_x in folded_xs): raise NotConstantValueError diff --git a/requirements-dev.txt b/requirements-dev.txt index 082eab73ce..d98bf2ec1e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,7 +17,7 @@ numpydoc pandas>=0.24.0 polyagamma pre-commit>=2.8.0 -pytensor>=2.25.1,<2.26 +pytensor>=2.26.2,<2.27 pytest-cov>=2.5 pytest>=3.0 rich>=13.7.1 diff --git a/requirements.txt b/requirements.txt index b59ca29127..05dcb1cdb1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ cachetools>=4.2.1 cloudpickle numpy>=1.15.0 pandas>=0.24.0 -pytensor>=2.25.1,<2.26 +pytensor>=2.26.1,<2.27 rich>=13.7.1 scipy>=1.4.1 threadpoolctl>=3.1.0,<4.0.0 diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index b3564cac1f..562bb49b55 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -696,6 +696,11 @@ def test_inputs_preserved(self): (out_shape,) = constant_fold((out.shape[0],), raise_not_constant=False) assert out_shape is a + def test_constant_fold_alloc(self): + # By default, Alloc outputs cannot be constant folded + x = pt.alloc(pt.arange(5), 2, 5) + np.testing.assert_allclose(constant_fold([x])[0], np.broadcast_to(np.arange(5), (2, 5))) + def test_replace_vars_in_graphs(): inp = shared(0.0, name="inp")