From a5e087e406bd7eb31aeafc62af8c4dc8b2a24acc Mon Sep 17 00:00:00 2001 From: theorashid Date: Mon, 12 Dec 2022 10:04:41 +0000 Subject: [PATCH 1/6] Add `HalfNormalRV` JAX implementation (#1362) --- pytensor/link/jax/dispatch/random.py | 22 ++++++++++++++++++++++ tests/link/jax/test_random.py | 16 ++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 002d0758e6..15466b7733 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -251,6 +251,28 @@ def sample_fn(rng, size, dtype, *parameters): return sample_fn +@jax_sample_fn.register(aer.HalfNormalRV) +def jax_sample_fn_halfnormal(op): + """JAX implementation of `HalfNormalRV`.""" + + def sample_fn(rng, size, dtype, *parameters): + rng_key = rng["jax_state"] + rng_key, sampling_key = jax.random.split(rng_key, 2) + ( + loc, + scale, + ) = parameters + sample = ( + loc + + jax.random.truncated_normal(sampling_key, 0.0, jax.numpy.inf, size, dtype) + * scale + ) + rng["jax_state"] = rng_key + return (rng, sample) + + return sample_fn + + @jax_sample_fn.register(aer.ChoiceRV) def jax_funcify_choice(op): """JAX implementation of `ChoiceRV`.""" diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 9dd5e61772..b106721688 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -280,6 +280,22 @@ def test_random_updates(rng_ctor): "uniform", lambda *args: args, ), + ( + aer.halfnormal, + [ + set_test_value( + at.dvector(), + np.array([-1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1000.0, dtype=np.float64), + ), + ], + (2,), + "halfnorm", + lambda *args: args, + ), ], ) def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv): From 7b7bc6d3830acfed37dad4a8025465ae4d3206bd Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 14 Dec 2022 16:35:04 +0100 Subject: [PATCH 2/6] Update pytensor/link/jax/dispatch/random.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pytensor/link/jax/dispatch/random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 15466b7733..0114d5df36 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -264,7 +264,7 @@ def sample_fn(rng, size, dtype, *parameters): ) = parameters sample = ( loc - + jax.random.truncated_normal(sampling_key, 0.0, jax.numpy.inf, size, dtype) + + jax.numpy.abs(jax.random.normal(sampling_key, size, dtype)) * scale ) rng["jax_state"] = rng_key From 88e88438ef3441145cc838454ca16adb5db8263f Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 14 Dec 2022 18:09:46 +0100 Subject: [PATCH 3/6] Fix formatting. --- pytensor/link/jax/dispatch/random.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 0114d5df36..05e3a82212 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -262,11 +262,7 @@ def sample_fn(rng, size, dtype, *parameters): loc, scale, ) = parameters - sample = ( - loc - + jax.numpy.abs(jax.random.normal(sampling_key, size, dtype)) - * scale - ) + sample = loc + jax.numpy.abs(jax.random.normal(sampling_key, size, dtype) * scale rng["jax_state"] = rng_key return (rng, sample) From 6478b10e8b2ef76f8beacfe79551651bab0abdd0 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 14 Dec 2022 22:36:24 +0100 Subject: [PATCH 4/6] Reformat. --- pytensor/link/jax/dispatch/random.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 05e3a82212..19c72745bc 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -258,10 +258,7 @@ def jax_sample_fn_halfnormal(op): def sample_fn(rng, size, dtype, *parameters): rng_key = rng["jax_state"] rng_key, sampling_key = jax.random.split(rng_key, 2) - ( - loc, - scale, - ) = parameters + loc, scale = parameters sample = loc + jax.numpy.abs(jax.random.normal(sampling_key, size, dtype) * scale rng["jax_state"] = rng_key return (rng, sample) From d563ecd25ebb60226a25b8e0f2bf741e7819d931 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 14 Dec 2022 22:36:53 +0100 Subject: [PATCH 5/6] Fix missing paranthesis. --- pytensor/link/jax/dispatch/random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 19c72745bc..fc825700c6 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -259,7 +259,7 @@ def sample_fn(rng, size, dtype, *parameters): rng_key = rng["jax_state"] rng_key, sampling_key = jax.random.split(rng_key, 2) loc, scale = parameters - sample = loc + jax.numpy.abs(jax.random.normal(sampling_key, size, dtype) * scale + sample = loc + jax.numpy.abs(jax.random.normal(sampling_key, size, dtype)) * scale rng["jax_state"] = rng_key return (rng, sample) From fac9a48cad01d08377ca25f76c421057a2ea8db0 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Thu, 15 Dec 2022 10:12:06 +0100 Subject: [PATCH 6/6] Fix syntax error. --- pytensor/link/jax/dispatch/random.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index fc825700c6..5326db44ff 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -259,7 +259,9 @@ def sample_fn(rng, size, dtype, *parameters): rng_key = rng["jax_state"] rng_key, sampling_key = jax.random.split(rng_key, 2) loc, scale = parameters - sample = loc + jax.numpy.abs(jax.random.normal(sampling_key, size, dtype)) * scale + sample = ( + loc + jax.numpy.abs(jax.random.normal(sampling_key, size, dtype)) * scale + ) rng["jax_state"] = rng_key return (rng, sample)