Skip to content

Commit dcd24a3

Browse files
committed
Implement JAX dispatch for IntegersRV
1 parent bf51907 commit dcd24a3

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def sample_fn(rng, size, dtype, *parameters):
179179

180180

181181
@jax_sample_fn.register(aer.RandIntRV)
182+
@jax_sample_fn.register(aer.IntegersRV)
182183
@jax_sample_fn.register(aer.UniformRV)
183184
def jax_sample_fn_uniform(op):
184185
"""JAX implementation of random variables with uniform density.
@@ -188,6 +189,9 @@ def jax_sample_fn_uniform(op):
188189
189190
"""
190191
name = op.name
192+
# IntegersRV is equivalent to RandintRV
193+
if isinstance(op, aer.IntegersRV):
194+
name = "randint"
191195
jax_op = getattr(jax.random, name)
192196

193197
def sample_fn(rng, size, dtype, *parameters):

tests/link/jax/test_random.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,22 @@ def test_random_updates(rng_ctor):
237237
"randint",
238238
lambda *args: args,
239239
),
240+
(
241+
aer.integers,
242+
[
243+
set_test_value(
244+
at.lscalar(),
245+
np.array(0, dtype=np.int64),
246+
),
247+
set_test_value( # high-value necessary since test on cdf
248+
at.lscalar(),
249+
np.array(1000, dtype=np.int64),
250+
),
251+
],
252+
(),
253+
"randint",
254+
lambda *args: args,
255+
),
240256
(
241257
aer.standard_normal,
242258
[],
@@ -376,7 +392,11 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
376392
The parameters passed to the op.
377393
378394
"""
379-
rng = shared(np.random.RandomState(29402))
395+
if rv_op is aer.integers:
396+
# Integers only accepts Generator, not RandomState
397+
rng = shared(np.random.default_rng(29402))
398+
else:
399+
rng = shared(np.random.RandomState(29402))
380400
g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
381401
g_fn = function(dist_params, g, mode=jax_mode)
382402
samples = g_fn(

0 commit comments

Comments
 (0)