10
10
import pytensor .tensor .random .basic as aer
11
11
from pytensor .link .jax .dispatch .basic import jax_funcify , jax_typify
12
12
from pytensor .link .jax .dispatch .shape import JAXShapeTuple
13
+ from pytensor .tensor .random .type import RandomType
13
14
from pytensor .tensor .shape import Shape , Shape_i
14
15
15
16
@@ -55,8 +56,7 @@ def jax_typify_RandomState(state, **kwargs):
55
56
state = state .get_state (legacy = False )
56
57
state ["bit_generator" ] = numpy_bit_gens [state ["bit_generator" ]]
57
58
# XXX: Is this a reasonable approach?
58
- state ["jax_state" ] = state ["state" ]["key" ][0 :2 ]
59
- return state
59
+ return state ["state" ]["key" ][0 :2 ]
60
60
61
61
62
62
@jax_typify .register (Generator )
@@ -81,7 +81,36 @@ def jax_typify_Generator(rng, **kwargs):
81
81
state_32 = _coerce_to_uint32_array (state ["state" ]["state" ])
82
82
state ["state" ]["inc" ] = inc_32 [0 ] << 32 | inc_32 [1 ]
83
83
state ["state" ]["state" ] = state_32 [0 ] << 32 | state_32 [1 ]
84
- return state
84
+ return state ["jax_state" ]
85
+
86
+
87
+ class RandomPRNGKeyType (RandomType [jax .random .PRNGKey ]):
88
+ """JAX-compatible PRNGKey type.
89
+
90
+ This type is not exposed to users directly.
91
+
92
+ It is introduced by the JIT linker in place of any RandomStateType or RandomGeneratorType
93
+ input variables used in the original function. Nodes in the function graph will
94
+ still show the original types as inputs and outputs.
95
+ """
96
+
97
+ def filter (self , data , strict : bool = False , allow_downcast = None ):
98
+ # PRNGs are just JAX Arrays, we assume this is a valid one!
99
+ if isinstance (data , jax .Array ):
100
+ return data
101
+
102
+ if strict :
103
+ raise TypeError ()
104
+
105
+ return jax_typify (data )
106
+
107
+
108
+ random_prng_key_type = RandomPRNGKeyType ()
109
+
110
+
111
+ @jax_typify .register (RandomType )
112
+ def jax_typify_RandomType (type ):
113
+ return random_prng_key_type ()
85
114
86
115
87
116
@jax_funcify .register (aer .RandomVariable )
@@ -128,12 +157,10 @@ def jax_sample_fn_generic(op):
128
157
name = op .name
129
158
jax_op = getattr (jax .random , name )
130
159
131
- def sample_fn (rng , size , dtype , * parameters ):
132
- rng_key = rng ["jax_state" ]
160
+ def sample_fn (rng_key , size , dtype , * parameters ):
133
161
rng_key , sampling_key = jax .random .split (rng_key , 2 )
134
162
sample = jax_op (sampling_key , * parameters , shape = size , dtype = dtype )
135
- rng ["jax_state" ] = rng_key
136
- return (rng , sample )
163
+ return (rng_key , sample )
137
164
138
165
return sample_fn
139
166
@@ -155,13 +182,11 @@ def jax_sample_fn_loc_scale(op):
155
182
name = op .name
156
183
jax_op = getattr (jax .random , name )
157
184
158
- def sample_fn (rng , size , dtype , * parameters ):
159
- rng_key = rng ["jax_state" ]
185
+ def sample_fn (rng_key , size , dtype , * parameters ):
160
186
rng_key , sampling_key = jax .random .split (rng_key , 2 )
161
187
loc , scale = parameters
162
188
sample = loc + jax_op (sampling_key , size , dtype ) * scale
163
- rng ["jax_state" ] = rng_key
164
- return (rng , sample )
189
+ return (rng_key , sample )
165
190
166
191
return sample_fn
167
192
@@ -173,12 +198,10 @@ def jax_sample_fn_no_dtype(op):
173
198
name = op .name
174
199
jax_op = getattr (jax .random , name )
175
200
176
- def sample_fn (rng , size , dtype , * parameters ):
177
- rng_key = rng ["jax_state" ]
201
+ def sample_fn (rng_key , size , dtype , * parameters ):
178
202
rng_key , sampling_key = jax .random .split (rng_key , 2 )
179
203
sample = jax_op (sampling_key , * parameters , shape = size )
180
- rng ["jax_state" ] = rng_key
181
- return (rng , sample )
204
+ return (rng_key , sample )
182
205
183
206
return sample_fn
184
207
@@ -199,15 +222,13 @@ def jax_sample_fn_uniform(op):
199
222
name = "randint"
200
223
jax_op = getattr (jax .random , name )
201
224
202
- def sample_fn (rng , size , dtype , * parameters ):
203
- rng_key = rng ["jax_state" ]
225
+ def sample_fn (rng_key , size , dtype , * parameters ):
204
226
rng_key , sampling_key = jax .random .split (rng_key , 2 )
205
227
minval , maxval = parameters
206
228
sample = jax_op (
207
229
sampling_key , shape = size , dtype = dtype , minval = minval , maxval = maxval
208
230
)
209
- rng ["jax_state" ] = rng_key
210
- return (rng , sample )
231
+ return (rng_key , sample )
211
232
212
233
return sample_fn
213
234
@@ -224,13 +245,11 @@ def jax_sample_fn_shape_rate(op):
224
245
name = op .name
225
246
jax_op = getattr (jax .random , name )
226
247
227
- def sample_fn (rng , size , dtype , * parameters ):
228
- rng_key = rng ["jax_state" ]
248
+ def sample_fn (rng_key , size , dtype , * parameters ):
229
249
rng_key , sampling_key = jax .random .split (rng_key , 2 )
230
250
(shape , rate ) = parameters
231
251
sample = jax_op (sampling_key , shape , size , dtype ) / rate
232
- rng ["jax_state" ] = rng_key
233
- return (rng , sample )
252
+ return (rng_key , sample )
234
253
235
254
return sample_fn
236
255
@@ -239,13 +258,11 @@ def sample_fn(rng, size, dtype, *parameters):
239
258
def jax_sample_fn_exponential (op ):
240
259
"""JAX implementation of `ExponentialRV`."""
241
260
242
- def sample_fn (rng , size , dtype , * parameters ):
243
- rng_key = rng ["jax_state" ]
261
+ def sample_fn (rng_key , size , dtype , * parameters ):
244
262
rng_key , sampling_key = jax .random .split (rng_key , 2 )
245
263
(scale ,) = parameters
246
264
sample = jax .random .exponential (sampling_key , size , dtype ) * scale
247
- rng ["jax_state" ] = rng_key
248
- return (rng , sample )
265
+ return (rng_key , sample )
249
266
250
267
return sample_fn
251
268
@@ -254,17 +271,15 @@ def sample_fn(rng, size, dtype, *parameters):
254
271
def jax_sample_fn_t (op ):
255
272
"""JAX implementation of `StudentTRV`."""
256
273
257
- def sample_fn (rng , size , dtype , * parameters ):
258
- rng_key = rng ["jax_state" ]
274
+ def sample_fn (rng_key , size , dtype , * parameters ):
259
275
rng_key , sampling_key = jax .random .split (rng_key , 2 )
260
276
(
261
277
df ,
262
278
loc ,
263
279
scale ,
264
280
) = parameters
265
281
sample = loc + jax .random .t (sampling_key , df , size , dtype ) * scale
266
- rng ["jax_state" ] = rng_key
267
- return (rng , sample )
282
+ return (rng_key , sample )
268
283
269
284
return sample_fn
270
285
@@ -273,13 +288,11 @@ def sample_fn(rng, size, dtype, *parameters):
273
288
def jax_funcify_choice (op ):
274
289
"""JAX implementation of `ChoiceRV`."""
275
290
276
- def sample_fn (rng , size , dtype , * parameters ):
277
- rng_key = rng ["jax_state" ]
291
+ def sample_fn (rng_key , size , dtype , * parameters ):
278
292
rng_key , sampling_key = jax .random .split (rng_key , 2 )
279
293
(a , p , replace ) = parameters
280
294
smpl_value = jax .random .choice (sampling_key , a , size , replace , p )
281
- rng ["jax_state" ] = rng_key
282
- return (rng , smpl_value )
295
+ return (rng_key , smpl_value )
283
296
284
297
return sample_fn
285
298
@@ -288,13 +301,11 @@ def sample_fn(rng, size, dtype, *parameters):
288
301
def jax_sample_fn_permutation (op ):
289
302
"""JAX implementation of `PermutationRV`."""
290
303
291
- def sample_fn (rng , size , dtype , * parameters ):
292
- rng_key = rng ["jax_state" ]
304
+ def sample_fn (rng_key , size , dtype , * parameters ):
293
305
rng_key , sampling_key = jax .random .split (rng_key , 2 )
294
306
(x ,) = parameters
295
307
sample = jax .random .permutation (sampling_key , x )
296
- rng ["jax_state" ] = rng_key
297
- return (rng , sample )
308
+ return (rng_key , sample )
298
309
299
310
return sample_fn
300
311
@@ -309,15 +320,12 @@ def jax_sample_fn_binomial(op):
309
320
310
321
from numpyro .distributions .util import binomial
311
322
312
- def sample_fn (rng , size , dtype , n , p ):
313
- rng_key = rng ["jax_state" ]
323
+ def sample_fn (rng_key , size , dtype , n , p ):
314
324
rng_key , sampling_key = jax .random .split (rng_key , 2 )
315
325
316
326
sample = binomial (key = sampling_key , n = n , p = p , shape = size )
317
327
318
- rng ["jax_state" ] = rng_key
319
-
320
- return (rng , sample )
328
+ return (rng_key , sample )
321
329
322
330
return sample_fn
323
331
@@ -332,15 +340,12 @@ def jax_sample_fn_multinomial(op):
332
340
333
341
from numpyro .distributions .util import multinomial
334
342
335
- def sample_fn (rng , size , dtype , n , p ):
336
- rng_key = rng ["jax_state" ]
343
+ def sample_fn (rng_key , size , dtype , n , p ):
337
344
rng_key , sampling_key = jax .random .split (rng_key , 2 )
338
345
339
346
sample = multinomial (key = sampling_key , n = n , p = p , shape = size )
340
347
341
- rng ["jax_state" ] = rng_key
342
-
343
- return (rng , sample )
348
+ return (rng_key , sample )
344
349
345
350
return sample_fn
346
351
@@ -355,17 +360,14 @@ def jax_sample_fn_vonmises(op):
355
360
356
361
from numpyro .distributions .util import von_mises_centered
357
362
358
- def sample_fn (rng , size , dtype , mu , kappa ):
359
- rng_key = rng ["jax_state" ]
363
+ def sample_fn (rng_key , size , dtype , mu , kappa ):
360
364
rng_key , sampling_key = jax .random .split (rng_key , 2 )
361
365
362
366
sample = von_mises_centered (
363
367
key = sampling_key , concentration = kappa , shape = size , dtype = dtype
364
368
)
365
369
sample = (sample + mu + np .pi ) % (2.0 * np .pi ) - np .pi
366
370
367
- rng ["jax_state" ] = rng_key
368
-
369
- return (rng , sample )
371
+ return (rng_key , sample )
370
372
371
373
return sample_fn
0 commit comments