Skip to content

Commit 1827703

Browse files
juanitorduztwiecki
juanitorduz
authored andcommitted
jax lognormal
1 parent 99510c3 commit 1827703

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,18 @@ def sample_fn(rng, size, dtype, *parameters):
277277
return (rng, sample)
278278

279279
return sample_fn
280+
281+
282+
@jax_sample_fn.register(aer.LogNormalRV)
283+
def jax_sample_fn_lognormal(op):
284+
"""JAX implementation of `LogNormalRV`."""
285+
286+
def sample_fn(rng, size, dtype, *parameters):
287+
rng_key = rng["jax_state"]
288+
loc, scale = parameters
289+
sample = loc + jax.random.normal(rng_key, size, dtype) * scale
290+
sample_exp = jax.numpy.exp(sample)
291+
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
292+
return (rng, sample_exp)
293+
294+
return sample_fn

tests/link/jax/test_random.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,22 @@ def test_random_updates(rng_ctor):
165165
"logistic",
166166
lambda *args: args,
167167
),
168+
(
169+
aer.lognormal,
170+
[
171+
set_test_value(
172+
at.lvector(),
173+
np.array([0, 0], dtype=np.int64),
174+
),
175+
set_test_value(
176+
at.dscalar(),
177+
np.array(1.0, dtype=np.float64),
178+
),
179+
],
180+
(2,),
181+
"lognorm",
182+
lambda *args: args,
183+
),
168184
(
169185
aer.normal,
170186
[

0 commit comments

Comments
 (0)