Skip to content

Commit 4eded29

Browse files
twieckitheorashidricardoV94
authored
🔄 From Aesara: 1362: "Add HalfNormalRV JAX implementation" (#129)
* Add `HalfNormalRV` JAX implementation (#1362) Co-authored-by: theorashid <theoaorashid@gmail.com> Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
1 parent e2ae99e commit 4eded29

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,23 @@ def sample_fn(rng, size, dtype, *parameters):
260260
return sample_fn
261261

262262

263+
@jax_sample_fn.register(aer.HalfNormalRV)
264+
def jax_sample_fn_halfnormal(op):
265+
"""JAX implementation of `HalfNormalRV`."""
266+
267+
def sample_fn(rng, size, dtype, *parameters):
268+
rng_key = rng["jax_state"]
269+
rng_key, sampling_key = jax.random.split(rng_key, 2)
270+
loc, scale = parameters
271+
sample = (
272+
loc + jax.numpy.abs(jax.random.normal(sampling_key, size, dtype)) * scale
273+
)
274+
rng["jax_state"] = rng_key
275+
return (rng, sample)
276+
277+
return sample_fn
278+
279+
263280
@jax_sample_fn.register(aer.ChoiceRV)
264281
def jax_funcify_choice(op):
265282
"""JAX implementation of `ChoiceRV`."""

tests/link/jax/test_random.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,22 @@ def test_random_updates(rng_ctor):
280280
"uniform",
281281
lambda *args: args,
282282
),
283+
(
284+
aer.halfnormal,
285+
[
286+
set_test_value(
287+
at.dvector(),
288+
np.array([-1.0, 2.0], dtype=np.float64),
289+
),
290+
set_test_value(
291+
at.dscalar(),
292+
np.array(1000.0, dtype=np.float64),
293+
),
294+
],
295+
(2,),
296+
"halfnorm",
297+
lambda *args: args,
298+
),
283299
],
284300
)
285301
def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):

0 commit comments

Comments
 (0)