@@ -125,8 +125,9 @@ def jax_sample_fn_generic(op):
125
125
126
126
def sample_fn (rng , size , dtype , * parameters ):
127
127
rng_key = rng ["jax_state" ]
128
- sample = jax_op (rng_key , * parameters , shape = size , dtype = dtype )
129
- rng ["jax_state" ] = jax .random .split (rng_key , num = 1 )[0 ]
128
+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
129
+ sample = jax_op (sampling_key , * parameters , shape = size , dtype = dtype )
130
+ rng ["jax_state" ] = rng_key
130
131
return (rng , sample )
131
132
132
133
return sample_fn
@@ -151,9 +152,10 @@ def jax_sample_fn_loc_scale(op):
151
152
152
153
def sample_fn (rng , size , dtype , * parameters ):
153
154
rng_key = rng ["jax_state" ]
155
+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
154
156
loc , scale = parameters
155
- sample = loc + jax_op (rng_key , size , dtype ) * scale
156
- rng ["jax_state" ] = jax . random . split ( rng_key , num = 1 )[ 0 ]
157
+ sample = loc + jax_op (sampling_key , size , dtype ) * scale
158
+ rng ["jax_state" ] = rng_key
157
159
return (rng , sample )
158
160
159
161
return sample_fn
@@ -168,8 +170,9 @@ def jax_sample_fn_no_dtype(op):
168
170
169
171
def sample_fn (rng , size , dtype , * parameters ):
170
172
rng_key = rng ["jax_state" ]
171
- sample = jax_op (rng_key , * parameters , shape = size )
172
- rng ["jax_state" ] = jax .random .split (rng_key , num = 1 )[0 ]
173
+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
174
+ sample = jax_op (sampling_key , * parameters , shape = size )
175
+ rng ["jax_state" ] = rng_key
173
176
return (rng , sample )
174
177
175
178
return sample_fn
@@ -189,9 +192,12 @@ def jax_sample_fn_uniform(op):
189
192
190
193
def sample_fn (rng , size , dtype , * parameters ):
191
194
rng_key = rng ["jax_state" ]
195
+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
192
196
minval , maxval = parameters
193
- sample = jax_op (rng_key , shape = size , dtype = dtype , minval = minval , maxval = maxval )
194
- rng ["jax_state" ] = jax .random .split (rng_key , num = 1 )[0 ]
197
+ sample = jax_op (
198
+ sampling_key , shape = size , dtype = dtype , minval = minval , maxval = maxval
199
+ )
200
+ rng ["jax_state" ] = rng_key
195
201
return (rng , sample )
196
202
197
203
return sample_fn
@@ -211,9 +217,10 @@ def jax_sample_fn_shape_rate(op):
211
217
212
218
def sample_fn (rng , size , dtype , * parameters ):
213
219
rng_key = rng ["jax_state" ]
220
+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
214
221
(shape , rate ) = parameters
215
- sample = jax_op (rng_key , shape , size , dtype ) / rate
216
- rng ["jax_state" ] = jax . random . split ( rng_key , num = 1 )[ 0 ]
222
+ sample = jax_op (sampling_key , shape , size , dtype ) / rate
223
+ rng ["jax_state" ] = rng_key
217
224
return (rng , sample )
218
225
219
226
return sample_fn
@@ -225,9 +232,10 @@ def jax_sample_fn_exponential(op):
225
232
226
233
def sample_fn (rng , size , dtype , * parameters ):
227
234
rng_key = rng ["jax_state" ]
235
+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
228
236
(scale ,) = parameters
229
- sample = jax .random .exponential (rng_key , size , dtype ) * scale
230
- rng ["jax_state" ] = jax . random . split ( rng_key , num = 1 )[ 0 ]
237
+ sample = jax .random .exponential (sampling_key , size , dtype ) * scale
238
+ rng ["jax_state" ] = rng_key
231
239
return (rng , sample )
232
240
233
241
return sample_fn
@@ -239,13 +247,14 @@ def jax_sample_fn_t(op):
239
247
240
248
def sample_fn (rng , size , dtype , * parameters ):
241
249
rng_key = rng ["jax_state" ]
250
+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
242
251
(
243
252
df ,
244
253
loc ,
245
254
scale ,
246
255
) = parameters
247
- sample = loc + jax .random .t (rng_key , df , size , dtype ) * scale
248
- rng ["jax_state" ] = jax . random . split ( rng_key , num = 1 )[ 0 ]
256
+ sample = loc + jax .random .t (sampling_key , df , size , dtype ) * scale
257
+ rng ["jax_state" ] = rng_key
249
258
return (rng , sample )
250
259
251
260
return sample_fn
@@ -257,9 +266,10 @@ def jax_funcify_choice(op):
257
266
258
267
def sample_fn (rng , size , dtype , * parameters ):
259
268
rng_key = rng ["jax_state" ]
269
+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
260
270
(a , p , replace ) = parameters
261
- smpl_value = jax .random .choice (rng_key , a , size , replace , p )
262
- rng ["jax_state" ] = jax . random . split ( rng_key , num = 1 )[ 0 ]
271
+ smpl_value = jax .random .choice (sampling_key , a , size , replace , p )
272
+ rng ["jax_state" ] = rng_key
263
273
return (rng , smpl_value )
264
274
265
275
return sample_fn
@@ -271,9 +281,10 @@ def jax_sample_fn_permutation(op):
271
281
272
282
def sample_fn (rng , size , dtype , * parameters ):
273
283
rng_key = rng ["jax_state" ]
284
+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
274
285
(x ,) = parameters
275
- sample = jax .random .permutation (rng_key , x )
276
- rng ["jax_state" ] = jax . random . split ( rng_key , num = 1 )[ 0 ]
286
+ sample = jax .random .permutation (sampling_key , x )
287
+ rng ["jax_state" ] = rng_key
277
288
return (rng , sample )
278
289
279
290
return sample_fn
@@ -285,10 +296,11 @@ def jax_sample_fn_lognormal(op):
285
296
286
297
def sample_fn (rng , size , dtype , * parameters ):
287
298
rng_key = rng ["jax_state" ]
299
+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
288
300
loc , scale = parameters
289
- sample = loc + jax .random .normal (rng_key , size , dtype ) * scale
301
+ sample = loc + jax .random .normal (sampling_key , size , dtype ) * scale
290
302
sample_exp = jax .numpy .exp (sample )
291
- rng ["jax_state" ] = jax . random . split ( rng_key , num = 1 )[ 0 ]
303
+ rng ["jax_state" ] = rng_key
292
304
return (rng , sample_exp )
293
305
294
306
return sample_fn
0 commit comments