Skip to content

Commit 76d27b2

Browse files
committed
Fix Scan JAX dispatcher
1 parent 2dc912d commit 76d27b2

File tree

2 files changed

+293
-70
lines changed

2 files changed

+293
-70
lines changed

pytensor/link/jax/dispatch/scan.py

Lines changed: 139 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,36 @@
11
import jax
22
import jax.numpy as jnp
33

4-
from pytensor.graph.fg import FunctionGraph
54
from pytensor.link.jax.dispatch.basic import jax_funcify
65
from pytensor.scan.op import Scan
76
from pytensor.scan.utils import ScanArgs
87

98

109
@jax_funcify.register(Scan)
1110
def jax_funcify_Scan(op, **kwargs):
12-
inner_fg = FunctionGraph(op.inputs, op.outputs)
13-
jax_at_inner_func = jax_funcify(inner_fg, **kwargs)
11+
info = op.info
12+
13+
if info.as_while:
14+
raise NotImplementedError("While Scan cannot yet be converted to JAX")
15+
16+
if info.n_mit_mot:
17+
raise NotImplementedError(
18+
"Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
19+
)
20+
21+
# Optimize inner graph
22+
fgraph = op.fgraph.clone()
23+
rewriter = op.mode_instance.optimizer
24+
rewriter(fgraph)
25+
scan_inner_func = jax_funcify(fgraph, **kwargs)
1426

1527
def scan(*outer_inputs):
1628
scan_args = ScanArgs(
17-
list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info
29+
list(outer_inputs),
30+
[None] * len(op.inner_outputs),
31+
op.inner_inputs,
32+
op.inner_outputs,
33+
op.info,
1834
)
1935

2036
# `outer_inputs` is a list with the following composite form:
@@ -29,31 +45,23 @@ def scan(*outer_inputs):
2945
n_steps = scan_args.n_steps
3046
seqs = scan_args.outer_in_seqs
3147

32-
# TODO: mit_mots
33-
mit_mot_in_slices = []
34-
3548
mit_sot_in_slices = []
3649
for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
37-
neg_taps = [abs(t) for t in tap if t < 0]
38-
pos_taps = [abs(t) for t in tap if t > 0]
39-
max_neg = max(neg_taps) if neg_taps else 0
40-
max_pos = max(pos_taps) if pos_taps else 0
41-
init_slice = seq[: max_neg + max_pos]
50+
init_slice = seq[: abs(min(tap))]
4251
mit_sot_in_slices.append(init_slice)
4352

4453
sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]
4554

4655
init_carry = (
47-
mit_mot_in_slices,
56+
[], # mit_mot_in_slices
4857
mit_sot_in_slices,
4958
sit_sot_in_slices,
5059
scan_args.outer_in_shared,
5160
scan_args.outer_in_non_seqs,
5261
)
5362

5463
def jax_args_to_inner_scan(op, carry, x):
55-
# `carry` contains all inner-output taps, non_seqs, and shared
56-
# terms
64+
# `carry` contains all inner-output taps, non_seqs, and shared terms
5765
(
5866
inner_in_mit_mot,
5967
inner_in_mit_sot,
@@ -76,6 +84,7 @@ def jax_args_to_inner_scan(op, carry, x):
7684
for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
7785
inner_in_mit_sot_flatten.extend(array[jnp.array(index)])
7886

87+
# Concatenate lists
7988
inner_scan_inputs = sum(
8089
[
8190
inner_in_seqs,
@@ -103,57 +112,131 @@ def inner_scan_outs_to_jax_outs(
103112
inner_in_non_seqs,
104113
) = old_carry
105114

106-
def update_mit_sot(mit_sot, new_val):
107-
return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)
115+
inner_out_mit_sot = inner_scan_outs[
116+
info.n_mit_mot : info.n_mit_mot + info.n_mit_sot
117+
]
118+
inner_in_mit_sot_new = []
119+
if inner_in_mit_sot:
120+
# Replace the oldest tap by the newest value
121+
inner_in_mit_sot_new = [
122+
jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0)
123+
for old_mit_sot, new_val in zip(
124+
inner_in_mit_sot,
125+
inner_out_mit_sot,
126+
)
127+
]
128+
129+
inner_out_sit_sot = inner_in_sit_sot_new = inner_scan_outs[
130+
info.n_mit_mot
131+
+ info.n_mit_sot : info.n_mit_mot
132+
+ info.n_mit_sot
133+
+ info.n_sit_sot
134+
]
108135

109-
inner_out_mit_sot = [
110-
update_mit_sot(mit_sot, new_val)
111-
for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
136+
inner_out_nit_sot = inner_scan_outs[
137+
info.n_mit_mot
138+
+ info.n_mit_sot
139+
+ info.n_sit_sot : info.n_mit_mot
140+
+ info.n_mit_sot
141+
+ info.n_sit_sot
142+
+ info.n_nit_sot :
112143
]
113144

114-
# This should contain all inner-output taps, non_seqs, and shared
115-
# terms
116-
if not inner_in_sit_sot:
117-
inner_out_sit_sot = []
118-
else:
119-
inner_out_sit_sot = inner_scan_outs
145+
inner_in_shared_new = inner_in_shared
146+
if info.n_shared_outs:
147+
# Replace old shared inputs by new shared outputs
148+
new_inner_out_shared = inner_scan_outs[
149+
info.n_mit_mot + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot :
150+
]
151+
inner_in_shared_new[: info.n_shared_outs] = new_inner_out_shared
152+
120153
new_carry = (
121-
inner_in_mit_mot,
122-
inner_out_mit_sot,
123-
inner_out_sit_sot,
124-
inner_in_shared,
154+
[], # MIT-MOT
155+
inner_in_mit_sot_new,
156+
inner_in_sit_sot_new,
157+
inner_in_shared_new,
125158
inner_in_non_seqs,
126159
)
127160

128-
return new_carry
161+
# Shared variables and non_seqs are not traced
162+
new_scan = sum(
163+
[
164+
[], # MIT-MOT
165+
inner_out_mit_sot,
166+
inner_out_sit_sot,
167+
inner_out_nit_sot,
168+
],
169+
[],
170+
)
171+
172+
return new_carry, new_scan
129173

130174
def jax_inner_func(carry, x):
131175
inner_args = jax_args_to_inner_scan(op, carry, x)
132-
inner_scan_outs = list(jax_at_inner_func(*inner_args))
133-
new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
134-
return new_carry, inner_scan_outs
135-
136-
_, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
137-
138-
# We need to prepend the initial values so that the JAX output will
139-
# match the raw `Scan` `Op` output and, thus, work with a downstream
140-
# `Subtensor` `Op` introduced by the `scan` helper function.
141-
def append_scan_out(scan_in_part, scan_out_part):
142-
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)
143-
144-
if scan_args.outer_in_mit_sot:
145-
scan_out_final = [
146-
append_scan_out(init, out)
147-
for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
148-
]
149-
elif scan_args.outer_in_sit_sot:
150-
scan_out_final = [
151-
append_scan_out(init, out)
152-
for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
153-
]
176+
inner_scan_outs = list(scan_inner_func(*inner_args))
177+
new_carry, new_scan_outs = inner_scan_outs_to_jax_outs(
178+
op, carry, inner_scan_outs
179+
)
180+
return new_carry, new_scan_outs
181+
182+
last_state, scan_traces = jax.lax.scan(
183+
jax_inner_func, init_carry, seqs, length=n_steps
184+
)
185+
186+
def get_partial_traces(scan_traces):
187+
# We need to prepend the initial values so that the JAX output will
188+
# match the raw `Scan` `Op` output and, thus, work with a downstream
189+
# `Subtensor` `Op` introduced by the `scan` helper function.
190+
init_states = (
191+
mit_sot_in_slices
192+
+ sit_sot_in_slices
193+
+ [None] * len(scan_args.outer_in_nit_sot)
194+
)
195+
buffers = (
196+
scan_args.outer_in_mit_sot
197+
+ scan_args.outer_in_sit_sot
198+
+ scan_args.outer_in_nit_sot
199+
)
200+
201+
partial_scan_traces = []
202+
for init_state, scan_trace, buffer in zip(
203+
init_states, scan_traces, buffers
204+
):
205+
if init_state is not None:
206+
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
207+
full_scan_trace = jnp.concatenate(
208+
[jnp.atleast_1d(init_state), jnp.atleast_1d(scan_trace)],
209+
axis=0,
210+
)
211+
partial_scan_trace = full_scan_trace[-buffer.shape[0] :]
212+
else:
213+
# NIT-SOT: Buffer is just the number of entries that should be returned
214+
partial_scan_trace = jnp.atleast_1d(scan_trace)[-buffer:]
215+
partial_scan_traces.append(partial_scan_trace)
216+
217+
return partial_scan_traces
218+
219+
def get_shared_outs(last_state):
220+
# Select the last state of shared_outs, these outputs are not traced
221+
if not info.n_shared_outs:
222+
return []
223+
224+
(
225+
inner_out_mit_mot,
226+
inner_out_mit_sot,
227+
inner_out_sit_sot,
228+
inner_out_shared,
229+
inner_in_non_seqs,
230+
) = last_state
231+
232+
# TODO: Check if a shared variable that is not an output shows up here or in non-seqs
233+
shared_outs = inner_out_shared[: info.n_shared_outs]
234+
return list(shared_outs)
235+
236+
scan_outs_final = get_partial_traces(scan_traces) + get_shared_outs(last_state)
154237

155-
if len(scan_out_final) == 1:
156-
scan_out_final = scan_out_final[0]
157-
return scan_out_final
238+
if len(scan_outs_final) == 1:
239+
scan_outs_final = scan_outs_final[0]
240+
return scan_outs_final
158241

159242
return scan

0 commit comments

Comments
 (0)