Describe the issue:
The Dirichlet distribution ignores the floatX config and creates float64 variables in the graph. With floatX="float32" and warn_float64="raise", registering the RV fails because SimplexTransform.forward produces a float64 intermediate (full traceback below).
Reproducible code example:
import pymc as pm
import pytensor.tensor as pt
import pytensor

def test_dirichlet():
    with pm.Model() as model:
        c = pm.floatX([1, 1, 1])
        print(c, c.dtype)
        d = pm.Dirichlet("a", c)
        print(model.point_logps())

with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet()
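The upcast appears to happen in SimplexTransform.forward (see the traceback below), where the float32 log-values are divided by the int64 shape entry value.shape[-1] and PyTensor promotes the result to float64, as NumPy would. A minimal sketch that isolates this promotion outside of PyMC (the variable x and the flag-free setup here are mine, not from the report):

import pytensor.tensor as pt

x = pt.vector("x", dtype="float32")
log_x = pt.log(x)                                        # still float32
shift = pt.sum(log_x, -1, keepdims=True) / x.shape[-1]   # float32 / int64 -> float64
print(shift.dtype)                                       # "float64", which is what trips warn_float64="raise"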
Error message:
---------------------------------------------------------------------------
Exception Traceback (most recent call last)
Cell In[10], line 2
1 with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
----> 2 test_dirichlet()
Cell In[8], line 5, in test_dirichlet()
3 c = pm.floatX([1, 1, 1])
4 print(c, c.dtype)
----> 5 d = pm.Dirichlet("a", c)
6 print(model.point_logps())
File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/distributions/distribution.py:314, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
310 kwargs["shape"] = tuple(observed.shape)
312 rv_out = cls.dist(*args, **kwargs)
--> 314 rv_out = model.register_rv(
315 rv_out,
316 name,
317 observed,
318 total_size,
319 dims=dims,
320 transform=transform,
321 initval=initval,
322 )
324 # add in pretty-printing support
325 rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/model.py:1333, in Model.register_rv(self, rv_var, name, observed, total_size, dims, transform, initval)
1331 raise ValueError("total_size can only be passed to observed RVs")
1332 self.free_RVs.append(rv_var)
-> 1333 self.create_value_var(rv_var, transform)
1334 self.add_named_variable(rv_var, dims)
1335 self.set_initval(rv_var, initval)
File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/model.py:1526, in Model.create_value_var(self, rv_var, transform, value_var)
1523 value_var.tag.test_value = rv_var.tag.test_value
1524 else:
1525 # Create value variable with the same type as the transformed RV
-> 1526 value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
1527 value_var.name = f"{rv_var.name}_{transform.name}__"
1528 value_var.tag.transform = transform
File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/logprob/transforms.py:985, in SimplexTransform.forward(self, value, *inputs)
983 def forward(self, value, *inputs):
984 log_value = pt.log(value)
--> 985 shift = pt.sum(log_value, -1, keepdims=True) / value.shape[-1]
986 return log_value[..., :-1] - shift
File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/var.py:173, in _tensor_py_operators.__truediv__(self, other)
172 def __truediv__(self, other):
--> 173 return at.math.true_div(self, other)
File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/graph/op.py:295, in Op.__call__(self, *inputs, **kwargs)
253 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
254
255 This method is just a wrapper around :meth:`Op.make_node`.
(...)
292
293 """
294 return_list = kwargs.pop("return_list", False)
--> 295 node = self.make_node(*inputs, **kwargs)
297 if config.compute_test_value != "off":
298 compute_test_value(node)
File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/elemwise.py:486, in Elemwise.make_node(self, *inputs)
484 inputs = [as_tensor_variable(i) for i in inputs]
485 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
--> 486 outputs = [
487 TensorType(dtype=dtype, shape=shape)()
488 for dtype, shape in zip(out_dtypes, out_shapes)
489 ]
490 return Apply(self, inputs, outputs)
File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/elemwise.py:487, in <listcomp>(.0)
484 inputs = [as_tensor_variable(i) for i in inputs]
485 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
486 outputs = [
--> 487 TensorType(dtype=dtype, shape=shape)()
488 for dtype, shape in zip(out_dtypes, out_shapes)
489 ]
490 return Apply(self, inputs, outputs)
File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/graph/type.py:228, in Type.__call__(self, name)
219 def __call__(self, name: Optional[str] = None) -> variable_type:
220 """Return a new `Variable` instance of Type `self`.
221
222 Parameters
(...)
226
227 """
--> 228 return utils.add_tag_trace(self.make_variable(name))
File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/graph/type.py:200, in Type.make_variable(self, name)
191 def make_variable(self, name: Optional[str] = None) -> variable_type:
192 """Return a new `Variable` instance of this `Type`.
193
194 Parameters
(...)
198
199 """
--> 200 return self.variable_type(self, None, name=name)
File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/var.py:863, in TensorVariable.__init__(self, type, owner, index, name)
861 warnings.warn(msg, stacklevel=1 + nb_rm)
862 elif config.warn_float64 == "raise":
--> 863 raise Exception(msg)
864 elif config.warn_float64 == "pdb":
865 import pdb
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
PyMC version information:
master
Context for the issue:
Related to pymc-devs/pymc-extras#182.
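In case it helps triage: one possible direction for a fix, sketched here only as an untested guess and not as the actual patch, would be to cast the divisor to the value's dtype inside SimplexTransform.forward so the division stays in floatX:

# hypothetical tweak to pymc/logprob/transforms.py::SimplexTransform.forward
def forward(self, value, *inputs):
    log_value = pt.log(value)
    n = pt.cast(value.shape[-1], value.dtype)  # avoid float32 / int64 -> float64 promotion
    shift = pt.sum(log_value, -1, keepdims=True) / n
    return log_value[..., :-1] - shift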