Closed
Description
Describe the issue:
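Forward sampling from a `HalfNormal` distribution fails when the sampling function is compiled in JAX mode. The error is raised by PyTensor's JAX implementation of `Second`, which calls `.shape` on a constant that arrives as a plain Python `float`.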
Reproducible code example:
# Minimal reproduction: prior predictive sampling of a HalfNormal variable
# fails when the sampler function is compiled with the JAX backend.
import pymc as pm
from pymc.pytensorf import get_mode

with pm.Model() as mod:
    x = pm.HalfNormal('x')
    # Raises AttributeError: 'float' object has no attribute 'shape',
    # originating in PyTensor's JAX implementation of `Second`.
    prior = pm.sample_prior_predictive(compile_kwargs={'mode': get_mode('JAX')})
Error message:
AttributeError Traceback (most recent call last)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
199 for thunk, node, old_storage in zip(
200 thunks, order, post_thunk_old_storage
201 ):
--> 202 thunk()
203 for old_s in old_storage:
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
663 def thunk(
664 fgraph=self.fgraph,
665 fgraph_jit=fgraph_jit,
666 thunk_inputs=thunk_inputs,
667 thunk_outputs=thunk_outputs,
668 ):
--> 669 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
671 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
[... skipping hidden 12 frame]
File /tmp/tmp3jdhf32h:3, in jax_funcified_fgraph(random_generator_shared_variable)
1 def jax_funcified_fgraph(random_generator_shared_variable):
2 # Second(0.0, 0.0)
----> 3 tensor_variable = second(tensor_constant, tensor_constant_1)
4 # normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FD6A9F4D8C0>), [], 11, Second.0, 1.0)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scalar.py:184, in jax_funcify_Second.<locals>.second(x, y)
183 def second(x, y):
--> 184 return jnp.broadcast_to(y, x.shape)
AttributeError: 'float' object has no attribute 'shape'
During handling of the above exception, another exception occurred:
AttributeError Traceback (most recent call last)
Cell In[54], line 3
1 with pm.Model() as mod:
2 x = pm.HalfNormal('x')
----> 3 prior = pm.sample_prior_predictive(compile_kwargs={'mode':get_mode('JAX')})
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/forward.py:425, in sample_prior_predictive(samples, model, var_names, random_seed, return_inferencedata, idata_kwargs, compile_kwargs)
423 # All model variables have a name, but mypy does not know this
424 _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
--> 425 values = zip(*(sampler_fn() for i in range(samples)))
427 data = {k: np.stack(v) for k, v in zip(names, values)}
428 if data is None:
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/forward.py:425, in <genexpr>(.0)
423 # All model variables have a name, but mypy does not know this
424 _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
--> 425 values = zip(*(sampler_fn() for i in range(samples)))
427 data = {k: np.stack(v) for k, v in zip(names, values)}
428 if data is None:
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
967 t0_fn = time.perf_counter()
968 try:
969 outputs = (
--> 970 self.vm()
971 if output_subset is None
972 else self.vm(output_subset=output_subset)
973 )
974 except Exception:
975 restore_defaults()
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:206, in streamline.<locals>.streamline_default_f()
204 old_s[0] = None
205 except Exception:
--> 206 raise_with_op(fgraph, node, thunk)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
530 warnings.warn(
531 f"{exc_type} error does not allow us to add an extra error message"
532 )
533 # Some exception need extra parameter in inputs. So forget the
534 # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
198 try:
199 for thunk, node, old_storage in zip(
200 thunks, order, post_thunk_old_storage
201 ):
--> 202 thunk()
203 for old_s in old_storage:
204 old_s[0] = None
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
663 def thunk(
664 fgraph=self.fgraph,
665 fgraph_jit=fgraph_jit,
666 thunk_inputs=thunk_inputs,
667 thunk_outputs=thunk_outputs,
668 ):
--> 669 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
671 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
672 compute_map[o_var][0] = True
[... skipping hidden 12 frame]
File /tmp/tmp3jdhf32h:3, in jax_funcified_fgraph(random_generator_shared_variable)
1 def jax_funcified_fgraph(random_generator_shared_variable):
2 # Second(0.0, 0.0)
----> 3 tensor_variable = second(tensor_constant, tensor_constant_1)
4 # normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FD6A9F4D8C0>), [], 11, Second.0, 1.0)
5 variable, tensor_variable_1 = sample_fn(random_generator_shared_variable, tensor_constant_2, tensor_constant_3, tensor_variable, tensor_constant_4)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scalar.py:184, in jax_funcify_Second.<locals>.second(x, y)
183 def second(x, y):
--> 184 return jnp.broadcast_to(y, x.shape)
AttributeError: 'float' object has no attribute 'shape'
Apply node that caused the error: Add(Abs.0, 0.0)
Toposort index: 3
Inputs types: [TensorType(float64, shape=()), TensorType(float32, shape=())]
Inputs shapes: ['No shapes']
Inputs strides: ['No strides']
Inputs values: [{'bit_generator': 1, 'state': {'state': -4621532023338195650, 'inc': 8471148850022962065}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([3218933020, 1456842046], dtype=uint32)}]
Outputs clients: [['output']]
HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
PyMC version information:
PyMC Version: 5.5.0
Pytensor version: 2.12.3
Context for the issue:
No response