Skip to content

HalfNormal in JAX failing due to implicit downcasting of constant 0d TensorVariable to float #373

Closed
@jessegrabowski

Description

@jessegrabowski

Describe the issue:

You can't forward-sample from a half-normal distribution in JAX mode: `pm.sample_prior_predictive` raises an `AttributeError` when compiled with the JAX backend.

Reproducible code example:

import pymc as pm
from pymc.pytensorf import get_mode

with pm.Model() as mod:
    x = pm.HalfNormal('x')
    prior = pm.sample_prior_predictive(compile_kwargs={'mode':get_mode('JAX')})

Error message:

AttributeError                            Traceback (most recent call last)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
    199 for thunk, node, old_storage in zip(
    200     thunks, order, post_thunk_old_storage
    201 ):
--> 202     thunk()
    203     for old_s in old_storage:

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    663 def thunk(
    664     fgraph=self.fgraph,
    665     fgraph_jit=fgraph_jit,
    666     thunk_inputs=thunk_inputs,
    667     thunk_outputs=thunk_outputs,
    668 ):
--> 669     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    671     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):

    [... skipping hidden 12 frame]

File /tmp/tmp3jdhf32h:3, in jax_funcified_fgraph(random_generator_shared_variable)
      1 def jax_funcified_fgraph(random_generator_shared_variable):
      2     # Second(0.0, 0.0)
----> 3     tensor_variable = second(tensor_constant, tensor_constant_1)
      4     # normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FD6A9F4D8C0>), [], 11, Second.0, 1.0)

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scalar.py:184, in jax_funcify_Second.<locals>.second(x, y)
    183 def second(x, y):
--> 184     return jnp.broadcast_to(y, x.shape)

AttributeError: 'float' object has no attribute 'shape'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
Cell In[54], line 3
      1 with pm.Model() as mod:
      2     x = pm.HalfNormal('x')
----> 3     prior = pm.sample_prior_predictive(compile_kwargs={'mode':get_mode('JAX')})

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/forward.py:425, in sample_prior_predictive(samples, model, var_names, random_seed, return_inferencedata, idata_kwargs, compile_kwargs)
    423 # All model variables have a name, but mypy does not know this
    424 _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}")  # type: ignore
--> 425 values = zip(*(sampler_fn() for i in range(samples)))
    427 data = {k: np.stack(v) for k, v in zip(names, values)}
    428 if data is None:

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/forward.py:425, in <genexpr>(.0)
    423 # All model variables have a name, but mypy does not know this
    424 _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}")  # type: ignore
--> 425 values = zip(*(sampler_fn() for i in range(samples)))
    427 data = {k: np.stack(v) for k, v in zip(names, values)}
    428 if data is None:

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    967 t0_fn = time.perf_counter()
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:
    975     restore_defaults()

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:206, in streamline.<locals>.streamline_default_f()
    204             old_s[0] = None
    205 except Exception:
--> 206     raise_with_op(fgraph, node, thunk)

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    530     warnings.warn(
    531         f"{exc_type} error does not allow us to add an extra error message"
    532     )
    533     # Some exception need extra parameter in inputs. So forget the
    534     # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
    198 try:
    199     for thunk, node, old_storage in zip(
    200         thunks, order, post_thunk_old_storage
    201     ):
--> 202         thunk()
    203         for old_s in old_storage:
    204             old_s[0] = None

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    663 def thunk(
    664     fgraph=self.fgraph,
    665     fgraph_jit=fgraph_jit,
    666     thunk_inputs=thunk_inputs,
    667     thunk_outputs=thunk_outputs,
    668 ):
--> 669     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    671     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
    672         compute_map[o_var][0] = True

    [... skipping hidden 12 frame]

File /tmp/tmp3jdhf32h:3, in jax_funcified_fgraph(random_generator_shared_variable)
      1 def jax_funcified_fgraph(random_generator_shared_variable):
      2     # Second(0.0, 0.0)
----> 3     tensor_variable = second(tensor_constant, tensor_constant_1)
      4     # normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FD6A9F4D8C0>), [], 11, Second.0, 1.0)
      5     variable, tensor_variable_1 = sample_fn(random_generator_shared_variable, tensor_constant_2, tensor_constant_3, tensor_variable, tensor_constant_4)

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scalar.py:184, in jax_funcify_Second.<locals>.second(x, y)
    183 def second(x, y):
--> 184     return jnp.broadcast_to(y, x.shape)

AttributeError: 'float' object has no attribute 'shape'
Apply node that caused the error: Add(Abs.0, 0.0)
Toposort index: 3
Inputs types: [TensorType(float64, shape=()), TensorType(float32, shape=())]
Inputs shapes: ['No shapes']
Inputs strides: ['No strides']
Inputs values: [{'bit_generator': 1, 'state': {'state': -4621532023338195650, 'inc': 8471148850022962065}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([3218933020, 1456842046], dtype=uint32)}]
Outputs clients: [['output']]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

PyMC version information:

PyMC Version: 5.5.0 Pytensor version: 2.12.3

Context for the issue:

No response

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions