Skip to content

Commit 330bbcc

Browse files
committed
Bump PyTensor dependency
1 parent 0c6d0df commit 330bbcc

13 files changed

+33
-34
lines changed

conda-envs/environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ dependencies:
1313
- numpy>=1.15.0
1414
- pandas>=0.24.0
1515
- pip
16-
- pytensor>=2.23,<2.24
16+
- pytensor>=2.25.1,<2.26
1717
- python-graphviz
1818
- networkx
1919
- scipy>=1.4.1

conda-envs/environment-docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dependencies:
1111
- numpy>=1.15.0
1212
- pandas>=0.24.0
1313
- pip
14-
- pytensor>=2.23,<2.24
14+
- pytensor>=2.25.1,<2.26
1515
- python-graphviz
1616
- rich>=13.7.1
1717
- scipy>=1.4.1

conda-envs/environment-jax.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ dependencies:
2020
- numpyro>=0.8.0
2121
- pandas>=0.24.0
2222
- pip
23-
- pytensor>=2.23,<2.24
23+
- pytensor>=2.25.1,<2.26
2424
- python-graphviz
2525
- networkx
2626
- rich>=13.7.1

conda-envs/environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ dependencies:
1616
- numpy>=1.15.0
1717
- pandas>=0.24.0
1818
- pip
19-
- pytensor>=2.23,<2.24
19+
- pytensor>=2.25.1,<2.26
2020
- python-graphviz
2121
- networkx
2222
- rich>=13.7.1

conda-envs/windows-environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ dependencies:
1313
- numpy>=1.15.0
1414
- pandas>=0.24.0
1515
- pip
16-
- pytensor>=2.23,<2.24
16+
- pytensor>=2.25.1,<2.26
1717
- python-graphviz
1818
- networkx
1919
- rich>=13.7.1

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ dependencies:
1616
- numpy>=1.15.0
1717
- pandas>=0.24.0
1818
- pip
19-
- pytensor>=2.23,<2.24
19+
- pytensor>=2.25.1,<2.26
2020
- python-graphviz
2121
- networkx
2222
- rich>=13.7.1

pymc/distributions/distribution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -836,11 +836,11 @@ def create_partial_observed_rv(
836836
if can_rewrite:
837837
masked_rv = rv[mask]
838838
fgraph = FunctionGraph(outputs=[masked_rv], clone=False, features=[ShapeFeature()])
839-
[unobserved_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
839+
unobserved_rv = local_subtensor_rv_lift.transform(fgraph, masked_rv.owner)[masked_rv]
840840

841841
antimasked_rv = rv[antimask]
842842
fgraph = FunctionGraph(outputs=[antimasked_rv], clone=False, features=[ShapeFeature()])
843-
[observed_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
843+
observed_rv = local_subtensor_rv_lift.transform(fgraph, antimasked_rv.owner)[antimasked_rv]
844844

845845
# Make a clone of the observedRV, with a distinct rng so that observed and
846846
# unobserved are never treated as equivalent (and mergeable) nodes by pytensor.

pymc/logprob/order.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -99,18 +99,20 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
9999
if not all(params.type.broadcastable):
100100
return None
101101

102-
# Check whether axis covers all dimensions
103-
axis = set(node.op.axis)
104-
base_var_dims = set(range(base_var.ndim))
105-
if axis != base_var_dims:
106-
return None
102+
if node.op.axis is None:
103+
axis = tuple(range(base_var.ndim))
104+
else:
105+
# Check whether axis covers all dimensions
106+
axis = tuple(sorted(node.op.axis))
107+
if axis != tuple(range(base_var.ndim)):
108+
return None
107109

108110
# distinguish measurable discrete and continuous (because logprob is different)
109111
measurable_max: Max
110112
if base_var.type.dtype.startswith("int"):
111-
measurable_max = MeasurableMaxDiscrete(list(axis))
113+
measurable_max = MeasurableMaxDiscrete(axis)
112114
else:
113-
measurable_max = MeasurableMax(list(axis))
115+
measurable_max = MeasurableMax(axis)
114116

115117
max_rv_node = measurable_max.make_node(base_var)
116118
max_rv = max_rv_node.outputs
@@ -206,21 +208,23 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVa
206208
if not all(params.type.broadcastable):
207209
return None
208210

209-
# Check whether axis is supported or not
210-
axis = set(node.op.axis)
211-
base_var_dims = set(range(base_var.ndim))
212-
if axis != base_var_dims:
213-
return None
211+
if node.op.axis is None:
212+
axis = tuple(range(base_var.ndim))
213+
else:
214+
# Check whether axis is supported or not
215+
axis = tuple(sorted(node.op.axis))
216+
if axis != tuple(range(base_var.ndim)):
217+
return None
214218

215219
if not rv_map_feature.request_measurable([base_rv]):
216220
return None
217221

218222
# distinguish measurable discrete and continuous (because logprob is different)
219223
measurable_min: Max
220224
if base_rv.type.dtype.startswith("int"):
221-
measurable_min = MeasurableDiscreteMaxNeg(list(axis))
225+
measurable_min = MeasurableDiscreteMaxNeg(axis)
222226
else:
223-
measurable_min = MeasurableMaxNeg(list(axis))
227+
measurable_min = MeasurableMaxNeg(axis)
224228

225229
return measurable_min.make_node(base_rv).outputs
226230

pymc/logprob/rewriting.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@
7272
from pytensor.tensor.rewriting.basic import register_canonicalize
7373
from pytensor.tensor.rewriting.math import local_exp_over_1_plus_exp
7474
from pytensor.tensor.rewriting.shape import ShapeFeature
75-
from pytensor.tensor.rewriting.uncanonicalize import local_max_and_argmax
7675
from pytensor.tensor.subtensor import (
7776
AdvancedIncSubtensor,
7877
AdvancedIncSubtensor1,
@@ -374,12 +373,6 @@ def incsubtensor_rv_replace(fgraph, node):
374373

375374
logprob_rewrites_db.register("measurable_ir_rewrites", measurable_ir_rewrites_db, "basic")
376375

377-
# Split max_and_argmax
378-
# We only register this in the measurable IR db because max does not have a grad implemented
379-
# And running this on any MaxAndArgmax would lead to issues: https://github.com/pymc-devs/pymc/issues/7251
380-
# This special registering can be removed after https://github.com/pymc-devs/pytensor/issues/334 is fixed
381-
measurable_ir_rewrites_db.register("local_max_and_argmax", local_max_and_argmax, "basic")
382-
383376
# These rewrites push random/measurable variables "down", making them closer to
384377
# (or eventually) the graph outputs. Often this is done by lifting other `Op`s
385378
# "up" through the random/measurable variables and into their inputs.

pymc/logprob/tensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
# SOFTWARE.
3636

3737

38+
from pathlib import Path
39+
3840
import pytensor
3941

4042
from pytensor import tensor as pt
@@ -237,7 +239,7 @@ class MeasurableDimShuffle(DimShuffle):
237239

238240
# Need to get the absolute path of `c_func_file`, otherwise it tries to
239241
# find it locally and fails when a new `Op` is initialized
240-
c_func_file = DimShuffle.get_path(DimShuffle.c_func_file)
242+
c_func_file = str(DimShuffle.get_path(Path(DimShuffle.c_func_file)))
241243

242244

243245
MeasurableVariable.register(MeasurableDimShuffle)

pymc/pytensorf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
graph_inputs,
3737
walk,
3838
)
39-
from pytensor.graph.fg import FunctionGraph
39+
from pytensor.graph.fg import FunctionGraph, Output
4040
from pytensor.graph.op import Op
4141
from pytensor.scalar.basic import Cast
4242
from pytensor.scan.op import Scan
@@ -897,7 +897,7 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
897897
[client, _] = rng_clients[0]
898898

899899
# RNG is an output of the function, this is not a problem
900-
if client == "output":
900+
if isinstance(client.op, Output):
901901
return rng
902902

903903
# RNG is used by another operator, which should output an update for the RNG

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ numpydoc
1717
pandas>=0.24.0
1818
polyagamma
1919
pre-commit>=2.8.0
20-
pytensor>=2.23,<2.24
20+
pytensor>=2.25.1,<2.26
2121
pytest-cov>=2.5
2222
pytest>=3.0
2323
rich>=13.7.1

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ cachetools>=4.2.1
33
cloudpickle
44
numpy>=1.15.0
55
pandas>=0.24.0
6-
pytensor>=2.23,<2.24
6+
pytensor>=2.25.1,<2.26
77
rich>=13.7.1
88
scipy>=1.4.1
99
threadpoolctl>=3.1.0,<4.0.0

0 commit comments

Comments
 (0)