Description
Describe the issue:
This may be related to pymc-devs/pymc#6779
Description:
When creating a model with floatX="float32" that includes a single distribution (for example, a Dirichlet distribution), the floatX assignment is respected. When creating a model with two distributions (for example, a Dirichlet distribution as well as another distribution), however, the floatX assignment is NOT respected, but only upon sampling. This is a weird bug.
Expected Behavior
The model should respect floatX in all cases.
Actual Behavior
When the model includes two distributions (here, a Dirichlet distribution and then ANY other distribution), the graph includes a float64 variable despite the request that floatX="float32".
Minimum Working Example
In the following MWE, I create four models. The first has one Dirichlet distribution, the second has one Normal distribution, and the remaining two include a Dirichlet distribution and then either a Normal or HalfCauchy distribution.
The first two models sample without issue, and floatX is respected.
The third and fourth models raise float64 errors during sampling. The error appears after model.point_logps(), which is all that was checked in pymc-devs/pymc#6779.
The output (with truncated error messages) is appended below:
pytensor version: 2.18.6
pymc version: 5.10.4
pytensor.config.floatX = float64
test_dirichlet
pytensor.config.floatX = float32
foo float32
{'foo': -1.5}
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.
test_normal
pytensor.config.floatX = float32
foo float32
{'foo': -0.92}
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 2 seconds.
test_dirichlet_normal
pytensor.config.floatX = float32
foo float32
bar float32
{'foo': -1.5, 'bar': -0.92}
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
...
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo, bar]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.
test_dirichlet_halfcauchy
pytensor.config.floatX = float32
foo float32
bar float32
{'foo': -1.5, 'bar': -1.14}
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
...
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [foo, bar]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 3 seconds.
Note that the output of print(model.point_logps()) demonstrates that the error occurs after model.point_logps(): the dictionary of log-probabilities is printed before the first ERROR line. The error occurs during sampling.
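To make the pattern in the output easier to see at a glance, the following is a small standard-library sketch that groups "Rewrite failure" lines by the test that printed them. It is only a triage helper, not part of PyMC or PyTensor; the names LOG, FAILURE, and failing_rewrites are made up for illustration, and LOG is a shortened excerpt of the output quoted in this report.

```python
import re

# Shortened excerpt of the console output quoted in this report.
LOG = """\
test_dirichlet
NUTS: [foo]
test_normal
NUTS: [foo]
test_dirichlet_normal
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
NUTS: [foo, bar]
test_dirichlet_halfcauchy
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
NUTS: [foo, bar]
"""

# Matches the rewrite name in a "Rewrite failure due to: <name>" log line.
FAILURE = re.compile(r"Rewrite failure due to: (\w+)")


def failing_rewrites(log: str) -> dict[str, list[str]]:
    """Map each test name to the rewrites that failed while it ran."""
    results: dict[str, list[str]] = {}
    current = None
    for line in log.splitlines():
        if line.startswith("test_"):
            # A line such as "test_dirichlet" starts a new test section.
            current = line.strip()
            results[current] = []
        elif current is not None:
            match = FAILURE.search(line)
            if match:
                results[current].append(match.group(1))
    return results


summary = failing_rewrites(LOG)
for name, rewrites in summary.items():
    print(name, "->", rewrites or "no rewrite failures")
```

With the excerpt above, only the two models that contain two distributions are associated with a failing rewrite, and in both cases it is local_sum_make_vector.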
Reproducible code example:
import pytensor
import pytensor.tensor as pt
import pymc as pm

print("pytensor version:", pytensor.__version__)
print("pymc version:", pm.__version__)
print("pytensor.config.floatX = ", pytensor.config.floatX)
print()


def test_dirichlet():
    print("test_dirichlet")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


def test_normal():
    print("test_normal")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Normal("foo", mu=0.0, sigma=1.0)
        print(foo, foo.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


def test_dirichlet_normal():
    print("test_dirichlet_normal")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
        bar = pm.Normal("bar", mu=0.0, sigma=1.0)
        print(bar, bar.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


def test_dirichlet_halfcauchy():
    print("test_dirichlet_halfcauchy")
    print("pytensor.config.floatX = ", pytensor.config.floatX)
    with pm.Model() as model:
        foo = pm.Dirichlet("foo", a=pt.ones(3))
        print(foo, foo.dtype)
        bar = pm.HalfCauchy("bar", beta=1.0)
        print(bar, bar.dtype)
    print(model.point_logps())
    with model:
        trace = pm.sample()
    print()


with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet()
with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_normal()
with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet_normal()
with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet_halfcauchy()
Error message:
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_sum_make_vector
ERROR (pytensor.graph.rewriting.basic): node: Sum{axes=None}(MakeVector{dtype='float32'}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1922, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1082, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/rewriting/basic.py", line 950, in local_sum_make_vector
    add(*[cast(value, acc_dtype) for value in elements]), out_dtype
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/rewriting/basic.py", line 950, in <listcomp>
    add(*[cast(value, acc_dtype) for value in elements]), out_dtype
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/basic.py", line 763, in cast
    return _cast_mapping[dtype_name](x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/op.py", line 295, in __call__
    node = self.make_node(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/elemwise.py", line 484, in make_node
    outputs = [
              ^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/elemwise.py", line 485, in <listcomp>
    TensorType(dtype=dtype, shape=shape)()
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/type.py", line 228, in __call__
    return utils.add_tag_trace(self.make_variable(name))
                               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/graph/type.py", line 200, in make_variable
    return self.variable_type(self, None, name=name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pytensor/tensor/variable.py", line 900, in __init__
    raise Exception(msg)
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
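The traceback fails inside cast(value, acc_dtype), which suggests that the accumulator dtype of this Sum node is float64 even though its input is float32. The following is a standard-library-only sketch of why a sum might accumulate in a wider type than its inputs; it does not use PyTensor and makes no claim about how PyTensor implements this. The helper names (to_f32, sum_with_f32_accumulator, sum_with_f64_accumulator) are made up for illustration, and float32 arithmetic is emulated with the struct module.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (a float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sum_with_f32_accumulator(values):
    """Sum float32 inputs while rounding the running total to float32."""
    total = 0.0
    for v in values:
        total = to_f32(total + to_f32(v))
    return total


def sum_with_f64_accumulator(values):
    """Sum float32 inputs in float64, rounding to float32 only at the end."""
    total = 0.0
    for v in values:
        total += to_f32(v)
    return to_f32(total)


# One large term followed by many terms too small to change a float32 total.
values = [1.0] + [1e-8] * 1000

print("float32 accumulator:", sum_with_f32_accumulator(values))
print("float64 accumulator:", sum_with_f64_accumulator(values))
```

In this sketch the float32 accumulator never moves away from 1.0, because each small term is below half the float32 spacing at 1.0, while the float64 accumulator keeps the small terms and only rounds once at the end. If the rewrite casts its inputs to such a wider accumulator type, that cast would be the point where a float64 variable is created, which is where warn_float64="raise" triggers in the traceback.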
PyMC version information:
pymc version: 5.10.4
pytensor version: 2.18.6
Context for the issue:
Models that include a Dirichlet distribution as well as any other distribution cannot use float32 cleanly: with warn_float64="raise", the local_sum_make_vector rewrite fails during sampling.