Describe the issue:
ZeroSumNormal does not follow floatX when it is set to float32.
The issue has been mentioned here, but the PR (merged in 5.6.0) doesn't seem to resolve it.
Also related: #6871
Reproducible code example:
```python
import pymc as pm
import pytensor

with pytensor.config.change_flags(floatX="float32", warn_float64="raise"):
    with pm.Model():
        # no issues here
        a = pm.Normal("a", 0, 1)

    with pm.Model():
        # errors here
        b = pm.ZeroSumNormal("b", 1, shape=(5,))
```
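The failure happens while creating the transformed value variable, so without `warn_float64="raise"` the model builds silently but carries a float64 value variable. A minimal sketch to confirm this, assuming `Model.rvs_to_values` (the mapping from RVs to their value variables):

```python
import pymc as pm
import pytensor

# Sketch: inspect the dtype of the value variable that the
# ZeroSumTransform produces (no warn_float64 flag, so no exception).
with pytensor.config.change_flags(floatX="float32"):
    with pm.Model() as m:
        b = pm.ZeroSumNormal("b", 1, shape=(5,))
    print(m.rvs_to_values[b].dtype)  # float64, expected float32
```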
Error message:

```
---------------------------------------------------------------------------
Exception Traceback (most recent call last)
Input In [10], in <cell line: 5>()
8 a = pm.Normal("a", 0, 1)
10 with pm.Model():
11 # errors here
---> 12 b = pm.ZeroSumNormal("b", 1, shape=(5,))
File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/distributions/multivariate.py:2512, in ZeroSumNormal.__new__(cls, zerosum_axes, n_zerosum_axes, support_shape, dims, *args, **kwargs)
2502 n_zerosum_axes = cls.check_zerosum_axes(n_zerosum_axes)
2504 support_shape = get_support_shape(
2505 support_shape=support_shape,
2506 shape=None, # Shape will be checked in `cls.dist`
(...)
2509 ndim_supp=n_zerosum_axes,
2510 )
-> 2512 return super().__new__(
2513 cls,
2514 *args,
2515 n_zerosum_axes=n_zerosum_axes,
2516 support_shape=support_shape,
2517 dims=dims,
2518 **kwargs,
2519 )
File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/distributions/distribution.py:316, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
312 kwargs["shape"] = tuple(observed.shape)
314 rv_out = cls.dist(*args, **kwargs)
--> 316 rv_out = model.register_rv(
317 rv_out,
318 name,
319 observed,
320 total_size,
321 dims=dims,
322 transform=transform,
323 initval=initval,
324 )
326 # add in pretty-printing support
327 rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/model.py:1289, in Model.register_rv(self, rv_var, name, observed, total_size, dims, transform, initval)
1287 raise ValueError("total_size can only be passed to observed RVs")
1288 self.free_RVs.append(rv_var)
-> 1289 self.create_value_var(rv_var, transform)
1290 self.add_named_variable(rv_var, dims)
1291 self.set_initval(rv_var, initval)
File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/model.py:1442, in Model.create_value_var(self, rv_var, transform, value_var)
1439 value_var.tag.test_value = rv_var.tag.test_value
1440 else:
1441 # Create value variable with the same type as the transformed RV
-> 1442 value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
1443 value_var.name = f"{rv_var.name}_{transform.name}__"
1444 value_var.tag.transform = transform
File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/distributions/transforms.py:320, in ZeroSumTransform.forward(self, value, *rv_inputs)
318 def forward(self, value, *rv_inputs):
319 for axis in self.zerosum_axes:
--> 320 value = extend_axis_rev(value, axis=axis)
321 return value
File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/distributions/transforms.py:348, in extend_axis_rev(array, axis)
345 n = array.shape[normalized_axis]
346 last = pt.take(array, [-1], axis=normalized_axis)
--> 348 sum_vals = -last * pt.sqrt(n)
349 norm = sum_vals / (pt.sqrt(n) + n)
350 slice_before = (slice(None, None),) * normalized_axis
File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/graph/op.py:304, in Op.__call__(self, *inputs, **kwargs)
262 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
263
264 This method is just a wrapper around :meth:`Op.make_node`.
(...)
301
302 """
303 return_list = kwargs.pop("return_list", False)
--> 304 node = self.make_node(*inputs, **kwargs)
306 if config.compute_test_value != "off":
307 compute_test_value(node)
File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/tensor/elemwise.py:497, in Elemwise.make_node(self, *inputs)
495 inputs = [as_tensor_variable(i) for i in inputs]
496 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
--> 497 outputs = [
498 TensorType(dtype=dtype, shape=shape)()
499 for dtype, shape in zip(out_dtypes, out_shapes)
500 ]
501 return Apply(self, inputs, outputs)
File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/tensor/elemwise.py:498, in <listcomp>(.0)
495 inputs = [as_tensor_variable(i) for i in inputs]
496 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
497 outputs = [
--> 498 TensorType(dtype=dtype, shape=shape)()
499 for dtype, shape in zip(out_dtypes, out_shapes)
500 ]
501 return Apply(self, inputs, outputs)
File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/graph/type.py:228, in Type.__call__(self, name)
219 def __call__(self, name: Optional[str] = None) -> variable_type:
220 """Return a new `Variable` instance of Type `self`.
221
222 Parameters
(...)
226
227 """
--> 228 return utils.add_tag_trace(self.make_variable(name))
File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/graph/type.py:200, in Type.make_variable(self, name)
191 def make_variable(self, name: Optional[str] = None) -> variable_type:
192 """Return a new `Variable` instance of this `Type`.
193
194 Parameters
(...)
198
199 """
--> 200 return self.variable_type(self, None, name=name)
File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/tensor/var.py:860, in TensorVariable.__init__(self, type, owner, index, name)
858 warnings.warn(msg, stacklevel=1 + nb_rm)
859 elif config.warn_float64 == "raise":
--> 860 raise Exception(msg)
861 elif config.warn_float64 == "pdb":
862 import pdb
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
```
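The traceback points at `extend_axis_rev`, where `sum_vals = -last * pt.sqrt(n)` multiplies the float32 array by `pt.sqrt` of the int64 symbolic shape `n`, and `sqrt` of an integer upcasts to float64 regardless of floatX. A minimal sketch of the suspected upcast (casting the shape to floatX as a fix is my assumption, not a confirmed patch):

```python
import pytensor
import pytensor.tensor as pt

with pytensor.config.change_flags(floatX="float32"):
    x = pt.vector("x")  # float32 under floatX="float32"
    n = x.shape[0]      # symbolic shapes are int64
    print(pt.sqrt(n).dtype)            # float64: sqrt of an int64 upcasts
    print((x[-1] * pt.sqrt(n)).dtype)  # float64: drags the float32 array along
    # Casting the shape to floatX first keeps the graph in float32:
    n_f = n.astype(pytensor.config.floatX)
    print((x[-1] * pt.sqrt(n_f)).dtype)  # float32
```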
PyMC version information:
PyMC/PyMC3 Version: 5.7.2
PyTensor/Aesara Version: 2.14.2
Python Version: 3.10.12
Operating system: Darwin arm64
How did you install PyMC/PyMC3: pip
Context for the issue:
ZeroSumNormal is needed for hierarchical models, and using float32 would be preferable since it gives sufficient precision with less memory.