Skip to content

Commit 09ce5f0

Browse files
committed
Typify RNG input variables in JAX linker
1 parent cd44a2b commit 09ce5f0

File tree

4 files changed

+94
-56
lines changed

4 files changed

+94
-56
lines changed

pytensor/link/basic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,9 @@ def create_thunk_inputs(self, storage_map: Dict[Variable, List[Any]]) -> List[An
609609
def jit_compile(self, fn: Callable) -> Callable:
610610
"""JIT compile a converted ``FunctionGraph``."""
611611

612+
def typify(self, var: Variable):
613+
return var
614+
612615
def output_filter(self, var: Variable, out: Any) -> Any:
613616
"""Apply a filter to the data output by a JITed function call."""
614617
return out
@@ -735,7 +738,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
735738
return (
736739
fn,
737740
[
738-
Container(input, storage)
741+
Container(self.typify(input), storage)
739742
for input, storage in zip(fgraph.inputs, input_storage)
740743
],
741744
[

pytensor/link/jax/dispatch/random.py

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytensor.tensor.random.basic as aer
1111
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
1212
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
13+
from pytensor.tensor.random.type import RandomType
1314
from pytensor.tensor.shape import Shape, Shape_i
1415

1516

@@ -55,8 +56,7 @@ def jax_typify_RandomState(state, **kwargs):
5556
state = state.get_state(legacy=False)
5657
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
5758
# XXX: Is this a reasonable approach?
58-
state["jax_state"] = state["state"]["key"][0:2]
59-
return state
59+
return state["state"]["key"][0:2]
6060

6161

6262
@jax_typify.register(Generator)
@@ -81,7 +81,36 @@ def jax_typify_Generator(rng, **kwargs):
8181
state_32 = _coerce_to_uint32_array(state["state"]["state"])
8282
state["state"]["inc"] = inc_32[0] << 32 | inc_32[1]
8383
state["state"]["state"] = state_32[0] << 32 | state_32[1]
84-
return state
84+
return state["jax_state"]
85+
86+
87+
class RandomPRNGKeyType(RandomType[jax.random.PRNGKey]):
88+
"""JAX-compatible PRNGKey type.
89+
90+
This type is not exposed to users directly.
91+
92+
It is introduced by the JIT linker in place of any RandomStateType or RandomGeneratorType
93+
input variables used in the original function. Nodes in the function graph will
94+
still show the original types as inputs and outputs.
95+
"""
96+
97+
def filter(self, data, strict: bool = False, allow_downcast=None):
98+
# PRNGs are just JAX Arrays, we assume this is a valid one!
99+
if isinstance(data, jax.Array):
100+
return data
101+
102+
if strict:
103+
raise TypeError()
104+
105+
return jax_typify(data)
106+
107+
108+
random_prng_key_type = RandomPRNGKeyType()
109+
110+
111+
@jax_typify.register(RandomType)
112+
def jax_typify_RandomType(type):
113+
return random_prng_key_type()
85114

86115

87116
@jax_funcify.register(aer.RandomVariable)
@@ -128,12 +157,10 @@ def jax_sample_fn_generic(op):
128157
name = op.name
129158
jax_op = getattr(jax.random, name)
130159

131-
def sample_fn(rng, size, dtype, *parameters):
132-
rng_key = rng["jax_state"]
160+
def sample_fn(rng_key, size, dtype, *parameters):
133161
rng_key, sampling_key = jax.random.split(rng_key, 2)
134162
sample = jax_op(sampling_key, *parameters, shape=size, dtype=dtype)
135-
rng["jax_state"] = rng_key
136-
return (rng, sample)
163+
return (rng_key, sample)
137164

138165
return sample_fn
139166

@@ -155,13 +182,11 @@ def jax_sample_fn_loc_scale(op):
155182
name = op.name
156183
jax_op = getattr(jax.random, name)
157184

158-
def sample_fn(rng, size, dtype, *parameters):
159-
rng_key = rng["jax_state"]
185+
def sample_fn(rng_key, size, dtype, *parameters):
160186
rng_key, sampling_key = jax.random.split(rng_key, 2)
161187
loc, scale = parameters
162188
sample = loc + jax_op(sampling_key, size, dtype) * scale
163-
rng["jax_state"] = rng_key
164-
return (rng, sample)
189+
return (rng_key, sample)
165190

166191
return sample_fn
167192

@@ -173,12 +198,10 @@ def jax_sample_fn_no_dtype(op):
173198
name = op.name
174199
jax_op = getattr(jax.random, name)
175200

176-
def sample_fn(rng, size, dtype, *parameters):
177-
rng_key = rng["jax_state"]
201+
def sample_fn(rng_key, size, dtype, *parameters):
178202
rng_key, sampling_key = jax.random.split(rng_key, 2)
179203
sample = jax_op(sampling_key, *parameters, shape=size)
180-
rng["jax_state"] = rng_key
181-
return (rng, sample)
204+
return (rng_key, sample)
182205

183206
return sample_fn
184207

@@ -199,15 +222,13 @@ def jax_sample_fn_uniform(op):
199222
name = "randint"
200223
jax_op = getattr(jax.random, name)
201224

202-
def sample_fn(rng, size, dtype, *parameters):
203-
rng_key = rng["jax_state"]
225+
def sample_fn(rng_key, size, dtype, *parameters):
204226
rng_key, sampling_key = jax.random.split(rng_key, 2)
205227
minval, maxval = parameters
206228
sample = jax_op(
207229
sampling_key, shape=size, dtype=dtype, minval=minval, maxval=maxval
208230
)
209-
rng["jax_state"] = rng_key
210-
return (rng, sample)
231+
return (rng_key, sample)
211232

212233
return sample_fn
213234

@@ -224,13 +245,11 @@ def jax_sample_fn_shape_rate(op):
224245
name = op.name
225246
jax_op = getattr(jax.random, name)
226247

227-
def sample_fn(rng, size, dtype, *parameters):
228-
rng_key = rng["jax_state"]
248+
def sample_fn(rng_key, size, dtype, *parameters):
229249
rng_key, sampling_key = jax.random.split(rng_key, 2)
230250
(shape, rate) = parameters
231251
sample = jax_op(sampling_key, shape, size, dtype) / rate
232-
rng["jax_state"] = rng_key
233-
return (rng, sample)
252+
return (rng_key, sample)
234253

235254
return sample_fn
236255

@@ -239,13 +258,11 @@ def sample_fn(rng, size, dtype, *parameters):
239258
def jax_sample_fn_exponential(op):
240259
"""JAX implementation of `ExponentialRV`."""
241260

242-
def sample_fn(rng, size, dtype, *parameters):
243-
rng_key = rng["jax_state"]
261+
def sample_fn(rng_key, size, dtype, *parameters):
244262
rng_key, sampling_key = jax.random.split(rng_key, 2)
245263
(scale,) = parameters
246264
sample = jax.random.exponential(sampling_key, size, dtype) * scale
247-
rng["jax_state"] = rng_key
248-
return (rng, sample)
265+
return (rng_key, sample)
249266

250267
return sample_fn
251268

@@ -254,17 +271,15 @@ def sample_fn(rng, size, dtype, *parameters):
254271
def jax_sample_fn_t(op):
255272
"""JAX implementation of `StudentTRV`."""
256273

257-
def sample_fn(rng, size, dtype, *parameters):
258-
rng_key = rng["jax_state"]
274+
def sample_fn(rng_key, size, dtype, *parameters):
259275
rng_key, sampling_key = jax.random.split(rng_key, 2)
260276
(
261277
df,
262278
loc,
263279
scale,
264280
) = parameters
265281
sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
266-
rng["jax_state"] = rng_key
267-
return (rng, sample)
282+
return (rng_key, sample)
268283

269284
return sample_fn
270285

@@ -273,13 +288,11 @@ def sample_fn(rng, size, dtype, *parameters):
273288
def jax_funcify_choice(op):
274289
"""JAX implementation of `ChoiceRV`."""
275290

276-
def sample_fn(rng, size, dtype, *parameters):
277-
rng_key = rng["jax_state"]
291+
def sample_fn(rng_key, size, dtype, *parameters):
278292
rng_key, sampling_key = jax.random.split(rng_key, 2)
279293
(a, p, replace) = parameters
280294
smpl_value = jax.random.choice(sampling_key, a, size, replace, p)
281-
rng["jax_state"] = rng_key
282-
return (rng, smpl_value)
295+
return (rng_key, smpl_value)
283296

284297
return sample_fn
285298

@@ -288,13 +301,11 @@ def sample_fn(rng, size, dtype, *parameters):
288301
def jax_sample_fn_permutation(op):
289302
"""JAX implementation of `PermutationRV`."""
290303

291-
def sample_fn(rng, size, dtype, *parameters):
292-
rng_key = rng["jax_state"]
304+
def sample_fn(rng_key, size, dtype, *parameters):
293305
rng_key, sampling_key = jax.random.split(rng_key, 2)
294306
(x,) = parameters
295307
sample = jax.random.permutation(sampling_key, x)
296-
rng["jax_state"] = rng_key
297-
return (rng, sample)
308+
return (rng_key, sample)
298309

299310
return sample_fn
300311

@@ -309,15 +320,12 @@ def jax_sample_fn_binomial(op):
309320

310321
from numpyro.distributions.util import binomial
311322

312-
def sample_fn(rng, size, dtype, n, p):
313-
rng_key = rng["jax_state"]
323+
def sample_fn(rng_key, size, dtype, n, p):
314324
rng_key, sampling_key = jax.random.split(rng_key, 2)
315325

316326
sample = binomial(key=sampling_key, n=n, p=p, shape=size)
317327

318-
rng["jax_state"] = rng_key
319-
320-
return (rng, sample)
328+
return (rng_key, sample)
321329

322330
return sample_fn
323331

@@ -332,15 +340,12 @@ def jax_sample_fn_multinomial(op):
332340

333341
from numpyro.distributions.util import multinomial
334342

335-
def sample_fn(rng, size, dtype, n, p):
336-
rng_key = rng["jax_state"]
343+
def sample_fn(rng_key, size, dtype, n, p):
337344
rng_key, sampling_key = jax.random.split(rng_key, 2)
338345

339346
sample = multinomial(key=sampling_key, n=n, p=p, shape=size)
340347

341-
rng["jax_state"] = rng_key
342-
343-
return (rng, sample)
348+
return (rng_key, sample)
344349

345350
return sample_fn
346351

@@ -355,17 +360,14 @@ def jax_sample_fn_vonmises(op):
355360

356361
from numpyro.distributions.util import von_mises_centered
357362

358-
def sample_fn(rng, size, dtype, mu, kappa):
359-
rng_key = rng["jax_state"]
363+
def sample_fn(rng_key, size, dtype, mu, kappa):
360364
rng_key, sampling_key = jax.random.split(rng_key, 2)
361365

362366
sample = von_mises_centered(
363367
key=sampling_key, concentration=kappa, shape=size, dtype=dtype
364368
)
365369
sample = (sample + mu + np.pi) % (2.0 * np.pi) - np.pi
366370

367-
rng["jax_state"] = rng_key
368-
369-
return (rng, sample)
371+
return (rng_key, sample)
370372

371373
return sample_fn

pytensor/link/jax/linker.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from numpy.random import Generator, RandomState
44

55
from pytensor.compile.sharedvalue import SharedVariable, shared
6-
from pytensor.graph.basic import Constant
6+
from pytensor.graph.basic import Constant, Variable
77
from pytensor.link.basic import JITLinker
88

99

@@ -63,6 +63,11 @@ def jit_compile(self, fn):
6363
]
6464
return jax.jit(fn, static_argnums=static_argnums)
6565

66+
def typify(self, var: Variable):
67+
from pytensor.link.jax.dispatch import jax_typify
68+
69+
return jax_typify(var.type)
70+
6671
def create_thunk_inputs(self, storage_map):
6772
from pytensor.link.jax.dispatch import jax_typify
6873

tests/link/jax/test_random.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pytensor.graph.basic import Constant
1313
from pytensor.graph.fg import FunctionGraph
1414
from pytensor.tensor.random.basic import RandomVariable
15+
from pytensor.tensor.random.type import random_generator_type
1516
from pytensor.tensor.random.utils import RandomStream
1617
from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value
1718

@@ -22,6 +23,33 @@
2223
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402
2324

2425

26+
def test_rng_io():
27+
rng = random_generator_type("rng")
28+
next_rng, x = aer.normal(rng=rng).owner.outputs
29+
fn = pytensor.function([rng], [next_rng, x], mode="JAX")
30+
31+
np_rng = np.random.default_rng(0)
32+
np_rst = np.random.RandomState(1)
33+
jx_rng = jax.random.PRNGKey(2)
34+
35+
# Inputs - RNG outputs
36+
assert isinstance(fn(np_rng)[0], jax.Array)
37+
assert isinstance(fn(np_rst)[0], jax.Array)
38+
assert isinstance(fn(jx_rng)[0], jax.Array)
39+
40+
# Inputs - Value outputs
41+
assert fn(np_rng)[1] == fn(np_rng)[1]
42+
assert fn(np_rst)[1] == fn(np_rst)[1]
43+
assert fn(jx_rng)[1] == fn(jx_rng)[1]
44+
assert fn(np_rng)[1] != fn(np_rst)[1]
45+
assert fn(np_rng)[1] != fn(jx_rng)[1]
46+
47+
# Chained Inputs - RNG / Value outputs
48+
assert fn(fn(np_rng)[0])[1] != fn(np_rng)[1]
49+
assert fn(fn(np_rst)[0])[1] != fn(np_rst)[1]
50+
assert fn(fn(jx_rng)[0])[1] != fn(jx_rng)[1]
51+
52+
2553
def test_random_RandomStream():
2654
"""Two successive calls of a compiled graph using `RandomStream` should
2755
return different values.

0 commit comments

Comments
 (0)