diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 002d0758e6..5326db44ff 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -251,6 +251,23 @@ def sample_fn(rng, size, dtype, *parameters): return sample_fn +@jax_sample_fn.register(aer.HalfNormalRV) +def jax_sample_fn_halfnormal(op): + """JAX implementation of `HalfNormalRV`.""" + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["jax_state"] + rng_key, sampling_key = jax.random.split(rng_key, 2) + loc, scale = parameters + sample = ( + loc + jax.numpy.abs(jax.random.normal(sampling_key, size, dtype)) * scale + ) + rng["jax_state"] = rng_key + return (rng, sample) + + return sample_fn + + @jax_sample_fn.register(aer.ChoiceRV) def jax_funcify_choice(op): """JAX implementation of `ChoiceRV`.""" diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 9dd5e61772..b106721688 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -280,6 +280,22 @@ def test_random_updates(rng_ctor): "uniform", lambda *args: args, ), + ( + aer.halfnormal, + [ + set_test_value( + at.dvector(), + np.array([-1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1000.0, dtype=np.float64), + ), + ], + (2,), + "halfnorm", + lambda *args: args, + ), ], ) def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):