Skip to content

Commit b626689

Browse files
AdrienCorenflostwiecki
authored andcommitted
Split RNG keys before using them in JAX backend
1 parent 5c63ee7 commit b626689

File tree

1 file changed

+32
-20
lines changed

1 file changed

+32
-20
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,9 @@ def jax_sample_fn_generic(op):
125125

126126
def sample_fn(rng, size, dtype, *parameters):
127127
rng_key = rng["jax_state"]
128-
sample = jax_op(rng_key, *parameters, shape=size, dtype=dtype)
129-
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
128+
rng_key, sampling_key = jax.random.split(rng_key, 2)
129+
sample = jax_op(sampling_key, *parameters, shape=size, dtype=dtype)
130+
rng["jax_state"] = rng_key
130131
return (rng, sample)
131132

132133
return sample_fn
@@ -151,9 +152,10 @@ def jax_sample_fn_loc_scale(op):
151152

152153
def sample_fn(rng, size, dtype, *parameters):
153154
rng_key = rng["jax_state"]
155+
rng_key, sampling_key = jax.random.split(rng_key, 2)
154156
loc, scale = parameters
155-
sample = loc + jax_op(rng_key, size, dtype) * scale
156-
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
157+
sample = loc + jax_op(sampling_key, size, dtype) * scale
158+
rng["jax_state"] = rng_key
157159
return (rng, sample)
158160

159161
return sample_fn
@@ -168,8 +170,9 @@ def jax_sample_fn_no_dtype(op):
168170

169171
def sample_fn(rng, size, dtype, *parameters):
170172
rng_key = rng["jax_state"]
171-
sample = jax_op(rng_key, *parameters, shape=size)
172-
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
173+
rng_key, sampling_key = jax.random.split(rng_key, 2)
174+
sample = jax_op(sampling_key, *parameters, shape=size)
175+
rng["jax_state"] = rng_key
173176
return (rng, sample)
174177

175178
return sample_fn
@@ -189,9 +192,12 @@ def jax_sample_fn_uniform(op):
189192

190193
def sample_fn(rng, size, dtype, *parameters):
191194
rng_key = rng["jax_state"]
195+
rng_key, sampling_key = jax.random.split(rng_key, 2)
192196
minval, maxval = parameters
193-
sample = jax_op(rng_key, shape=size, dtype=dtype, minval=minval, maxval=maxval)
194-
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
197+
sample = jax_op(
198+
sampling_key, shape=size, dtype=dtype, minval=minval, maxval=maxval
199+
)
200+
rng["jax_state"] = rng_key
195201
return (rng, sample)
196202

197203
return sample_fn
@@ -211,9 +217,10 @@ def jax_sample_fn_shape_rate(op):
211217

212218
def sample_fn(rng, size, dtype, *parameters):
213219
rng_key = rng["jax_state"]
220+
rng_key, sampling_key = jax.random.split(rng_key, 2)
214221
(shape, rate) = parameters
215-
sample = jax_op(rng_key, shape, size, dtype) / rate
216-
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
222+
sample = jax_op(sampling_key, shape, size, dtype) / rate
223+
rng["jax_state"] = rng_key
217224
return (rng, sample)
218225

219226
return sample_fn
@@ -225,9 +232,10 @@ def jax_sample_fn_exponential(op):
225232

226233
def sample_fn(rng, size, dtype, *parameters):
227234
rng_key = rng["jax_state"]
235+
rng_key, sampling_key = jax.random.split(rng_key, 2)
228236
(scale,) = parameters
229-
sample = jax.random.exponential(rng_key, size, dtype) * scale
230-
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
237+
sample = jax.random.exponential(sampling_key, size, dtype) * scale
238+
rng["jax_state"] = rng_key
231239
return (rng, sample)
232240

233241
return sample_fn
@@ -239,13 +247,14 @@ def jax_sample_fn_t(op):
239247

240248
def sample_fn(rng, size, dtype, *parameters):
241249
rng_key = rng["jax_state"]
250+
rng_key, sampling_key = jax.random.split(rng_key, 2)
242251
(
243252
df,
244253
loc,
245254
scale,
246255
) = parameters
247-
sample = loc + jax.random.t(rng_key, df, size, dtype) * scale
248-
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
256+
sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
257+
rng["jax_state"] = rng_key
249258
return (rng, sample)
250259

251260
return sample_fn
@@ -257,9 +266,10 @@ def jax_funcify_choice(op):
257266

258267
def sample_fn(rng, size, dtype, *parameters):
259268
rng_key = rng["jax_state"]
269+
rng_key, sampling_key = jax.random.split(rng_key, 2)
260270
(a, p, replace) = parameters
261-
smpl_value = jax.random.choice(rng_key, a, size, replace, p)
262-
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
271+
smpl_value = jax.random.choice(sampling_key, a, size, replace, p)
272+
rng["jax_state"] = rng_key
263273
return (rng, smpl_value)
264274

265275
return sample_fn
@@ -271,9 +281,10 @@ def jax_sample_fn_permutation(op):
271281

272282
def sample_fn(rng, size, dtype, *parameters):
273283
rng_key = rng["jax_state"]
284+
rng_key, sampling_key = jax.random.split(rng_key, 2)
274285
(x,) = parameters
275-
sample = jax.random.permutation(rng_key, x)
276-
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
286+
sample = jax.random.permutation(sampling_key, x)
287+
rng["jax_state"] = rng_key
277288
return (rng, sample)
278289

279290
return sample_fn
@@ -285,10 +296,11 @@ def jax_sample_fn_lognormal(op):
285296

286297
def sample_fn(rng, size, dtype, *parameters):
287298
rng_key = rng["jax_state"]
299+
rng_key, sampling_key = jax.random.split(rng_key, 2)
288300
loc, scale = parameters
289-
sample = loc + jax.random.normal(rng_key, size, dtype) * scale
301+
sample = loc + jax.random.normal(sampling_key, size, dtype) * scale
290302
sample_exp = jax.numpy.exp(sample)
291-
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
303+
rng["jax_state"] = rng_key
292304
return (rng, sample_exp)
293305

294306
return sample_fn

0 commit comments

Comments
 (0)