From e797cd9c7cf5f02aa9f56a27d5e1bdb3d79a0e55 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Tue, 17 Sep 2024 16:21:20 +0530 Subject: [PATCH 1/6] Add jax dispatch for truncated normal --- pymc/dispatch/dispatch_jax.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 pymc/dispatch/dispatch_jax.py diff --git a/pymc/dispatch/dispatch_jax.py b/pymc/dispatch/dispatch_jax.py new file mode 100644 index 0000000000..75cfff0bcf --- /dev/null +++ b/pymc/dispatch/dispatch_jax.py @@ -0,0 +1,28 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import jax + +from pytensor.link.jax.dispatch import jax_funcify + +from pymc.distributions.continous import TruncatedNormalRV + + +@jax_funcify.register(TruncatedNormalRV) +def jax_funcify_TruncatedNormalRV(op, **kwargs): + def trunc_normal_fn(key, size, mu, sigma, lower, upper): + return None, jax.random.truncated_normal( + key["jax_state"], lower=lower, upper=upper, shape=size + ) + + return trunc_normal_fn From 8a9f17eec8a36a234c02a98779eb9aac16d26586 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Tue, 17 Sep 2024 17:15:40 +0530 Subject: [PATCH 2/6] Add test for jax dispatch --- tests/dispatch/test_jax.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 tests/dispatch/test_jax.py diff --git a/tests/dispatch/test_jax.py b/tests/dispatch/test_jax.py new file mode 100644 index 0000000000..67aa695910 --- /dev/null +++ b/tests/dispatch/test_jax.py @@ -0,0 +1,36 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from pytensor import function + +import pymc as pm + +jax = pytest.importorskip("jax", reason="JAX is not installed") + + +def test_jax_TruncatedNormal(): + with pm.Model() as m: + f_jax = function( + [], + [pm.TruncatedNormal("a", 0, 1, lower=-1, upper=2, rng=np.random.default_rng(seed=123))], + mode="JAX", + ) + f_py = function( + [], + [pm.TruncatedNormal("b", 0, 1, lower=-1, upper=2, rng=np.random.default_rng(seed=123))], + ) + + assert jax.numpy.array_equal(a1=f_py(), a2=f_jax()) From fd3b6d2fd67c74e75d9b87a8c2d5c427e165372a Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 20 Sep 2024 00:49:47 +0530 Subject: [PATCH 3/6] Fix jax dispatch for TruncatedNormal --- pymc/dispatch/dispatch_jax.py | 12 ++++++++---- tests/dispatch/test_jax.py | 7 ++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/pymc/dispatch/dispatch_jax.py b/pymc/dispatch/dispatch_jax.py index 75cfff0bcf..15919e28a1 100644 --- a/pymc/dispatch/dispatch_jax.py +++ b/pymc/dispatch/dispatch_jax.py @@ -15,14 +15,18 @@ from pytensor.link.jax.dispatch import jax_funcify -from pymc.distributions.continous import TruncatedNormalRV +from pymc.distributions.continuous import TruncatedNormalRV @jax_funcify.register(TruncatedNormalRV) def jax_funcify_TruncatedNormalRV(op, **kwargs): def trunc_normal_fn(key, size, mu, sigma, lower, upper): - return None, jax.random.truncated_normal( - key["jax_state"], lower=lower, upper=upper, shape=size - ) + rng_key = key["jax_state"] + rng_key, sampling_key = jax.random.split(rng_key, 2) + key["jax_state"] = rng_key + + truncnorm = jax.nn.initializers.truncated_normal(sigma, lower=lower, upper=upper) + + return key, truncnorm(key["jax_state"], size) + mu return trunc_normal_fn diff --git a/tests/dispatch/test_jax.py b/tests/dispatch/test_jax.py index 67aa695910..1da33c1b0d 100644 --- a/tests/dispatch/test_jax.py +++ b/tests/dispatch/test_jax.py @@ -13,12 +13,13 @@ # limitations under the License. import numpy as np import pytest - +from pymc.dispatch import dispatch_jax from pytensor import function - import pymc as pm -jax = pytest.importorskip("jax", reason="JAX is not installed") + + +jax = pytest.importorskip("jax") def test_jax_TruncatedNormal(): From a7cb67bddff2aa8c6f6e397ba8fa3e01151e1641 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 20 Sep 2024 01:03:01 +0530 Subject: [PATCH 4/6] Fix ruff error --- tests/dispatch/test_jax.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/dispatch/test_jax.py b/tests/dispatch/test_jax.py index 1da33c1b0d..43c2767c88 100644 --- a/tests/dispatch/test_jax.py +++ b/tests/dispatch/test_jax.py @@ -13,11 +13,12 @@ # limitations under the License. import numpy as np import pytest -from pymc.dispatch import dispatch_jax + from pytensor import function -import pymc as pm +import pymc as pm +from pymc.dispatch import dispatch_jax # noqa: F401 jax = pytest.importorskip("jax") From 1de633c73fbed8827669ea211050928f0dd26202 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 20 Sep 2024 01:09:19 +0530 Subject: [PATCH 5/6] Fix check-no-tests-are-ignored error --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9ac9eff143..600898979e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -61,6 +61,7 @@ jobs: tests/distributions/test_shape_utils.py tests/distributions/test_mixture.py tests/test_testing.py + tests/dispatch/test_jax.py - | tests/distributions/test_continuous.py From b31dc5d4232b18cb0f9ac0298d168c10a76f7769 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sat, 12 Oct 2024 23:30:11 +0530 Subject: [PATCH 6/6] Fix failing test and parametrize test --- pymc/dispatch/dispatch_jax.py | 16 ++++++++++++++-- tests/dispatch/test_jax.py | 32 +++++++++++++++++++++++++------- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/pymc/dispatch/dispatch_jax.py b/pymc/dispatch/dispatch_jax.py index 15919e28a1..694c297363 100644 --- a/pymc/dispatch/dispatch_jax.py +++ b/pymc/dispatch/dispatch_jax.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import jax +import jax.numpy as jnp from pytensor.link.jax.dispatch import jax_funcify @@ -25,8 +26,19 @@ def trunc_normal_fn(key, size, mu, sigma, lower, upper): rng_key, sampling_key = jax.random.split(rng_key, 2) key["jax_state"] = rng_key - truncnorm = jax.nn.initializers.truncated_normal(sigma, lower=lower, upper=upper) + if lower is None: + lower = -jnp.inf + if upper is None: + upper = jnp.inf + else: + new_lower, new_upper = (lower - mu) / sigma, (upper - mu) / sigma - return key, truncnorm(key["jax_state"], size) + mu + if size is None: + size = jnp.broadcast_arrays(jnp.array(mu), jnp.array(sigma))[0].shape + + res = jax.random.truncated_normal(key["jax_state"], new_lower, new_upper, shape=size) + res = res * sigma + mu + + return key, res return trunc_normal_fn diff --git a/tests/dispatch/test_jax.py b/tests/dispatch/test_jax.py index 43c2767c88..0ff536eec7 100644 --- a/tests/dispatch/test_jax.py +++ b/tests/dispatch/test_jax.py @@ -23,16 +23,34 @@ jax = pytest.importorskip("jax") -def test_jax_TruncatedNormal(): +@pytest.mark.parametrize("sigma", [0.02, 5]) +def test_jax_TruncatedNormal(sigma): with pm.Model() as m: + lower = 5 + upper = 8 + mu = 6 + + a = pm.TruncatedNormal( + "a", mu, sigma, lower=lower, upper=upper, rng=np.random.default_rng(seed=123) + ) + f_jax = function( [], - [pm.TruncatedNormal("a", 0, 1, lower=-1, upper=2, rng=np.random.default_rng(seed=123))], + [ + pm.TruncatedNormal( + "b", + mu, + sigma, + lower=lower, + upper=upper, + rng=np.random.default_rng(seed=123), + ) + ], mode="JAX", ) - f_py = function( - [], - [pm.TruncatedNormal("b", 0, 1, lower=-1, upper=2, rng=np.random.default_rng(seed=123))], - ) + res = f_jax() + + draws = pm.draw(a, draws=100, mode="JAX") - assert jax.numpy.array_equal(a1=f_py(), a2=f_jax()) + assert jax.numpy.all((draws >= lower) & (draws <= upper)) + assert jax.numpy.all((res[0] >= lower) & (res[0] <= upper))