1
1
import jax
2
2
import jax .numpy as jnp
3
3
4
- from pytensor .graph .fg import FunctionGraph
5
4
from pytensor .link .jax .dispatch .basic import jax_funcify
6
5
from pytensor .scan .op import Scan
7
6
from pytensor .scan .utils import ScanArgs
8
7
9
8
10
9
@jax_funcify .register (Scan )
11
10
def jax_funcify_Scan (op , ** kwargs ):
12
- inner_fg = FunctionGraph (op .inputs , op .outputs )
13
- jax_at_inner_func = jax_funcify (inner_fg , ** kwargs )
11
+ info = op .info
12
+
13
+ if info .as_while :
14
+ raise NotImplementedError ("While Scan cannot yet be converted to JAX" )
15
+
16
+ if info .n_mit_mot :
17
+ raise NotImplementedError (
18
+ "Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
19
+ )
20
+
21
+ # Optimize inner graph
22
+ fgraph = op .fgraph .clone ()
23
+ rewriter = op .mode_instance .optimizer
24
+ rewriter (fgraph )
25
+ scan_inner_func = jax_funcify (fgraph , ** kwargs )
14
26
15
27
def scan (* outer_inputs ):
16
28
scan_args = ScanArgs (
17
- list (outer_inputs ), [None ] * op .info .n_outs , op .inputs , op .outputs , op .info
29
+ list (outer_inputs ),
30
+ [None ] * len (op .inner_outputs ),
31
+ op .inner_inputs ,
32
+ op .inner_outputs ,
33
+ op .info ,
18
34
)
19
35
20
36
# `outer_inputs` is a list with the following composite form:
@@ -29,31 +45,23 @@ def scan(*outer_inputs):
29
45
n_steps = scan_args .n_steps
30
46
seqs = scan_args .outer_in_seqs
31
47
32
- # TODO: mit_mots
33
- mit_mot_in_slices = []
34
-
35
48
mit_sot_in_slices = []
36
49
for tap , seq in zip (scan_args .mit_sot_in_slices , scan_args .outer_in_mit_sot ):
37
- neg_taps = [abs (t ) for t in tap if t < 0 ]
38
- pos_taps = [abs (t ) for t in tap if t > 0 ]
39
- max_neg = max (neg_taps ) if neg_taps else 0
40
- max_pos = max (pos_taps ) if pos_taps else 0
41
- init_slice = seq [: max_neg + max_pos ]
50
+ init_slice = seq [: abs (min (tap ))]
42
51
mit_sot_in_slices .append (init_slice )
43
52
44
53
sit_sot_in_slices = [seq [0 ] for seq in scan_args .outer_in_sit_sot ]
45
54
46
55
init_carry = (
47
- mit_mot_in_slices ,
56
+ [], # mit_mot_in_slices
48
57
mit_sot_in_slices ,
49
58
sit_sot_in_slices ,
50
59
scan_args .outer_in_shared ,
51
60
scan_args .outer_in_non_seqs ,
52
61
)
53
62
54
63
def jax_args_to_inner_scan (op , carry , x ):
55
- # `carry` contains all inner-output taps, non_seqs, and shared
56
- # terms
64
+ # `carry` contains all inner-output taps, non_seqs, and shared terms
57
65
(
58
66
inner_in_mit_mot ,
59
67
inner_in_mit_sot ,
@@ -76,6 +84,7 @@ def jax_args_to_inner_scan(op, carry, x):
76
84
for array , index in zip (inner_in_mit_sot , scan_args .mit_sot_in_slices ):
77
85
inner_in_mit_sot_flatten .extend (array [jnp .array (index )])
78
86
87
+ # Concatenate lists
79
88
inner_scan_inputs = sum (
80
89
[
81
90
inner_in_seqs ,
@@ -103,57 +112,131 @@ def inner_scan_outs_to_jax_outs(
103
112
inner_in_non_seqs ,
104
113
) = old_carry
105
114
106
- def update_mit_sot (mit_sot , new_val ):
107
- return jnp .concatenate ([mit_sot [1 :], new_val [None , ...]], axis = 0 )
115
+ inner_out_mit_sot = inner_scan_outs [
116
+ info .n_mit_mot : info .n_mit_mot + info .n_mit_sot
117
+ ]
118
+ inner_in_mit_sot_new = []
119
+ if inner_in_mit_sot :
120
+ # Replace the oldest tap by the newest value
121
+ inner_in_mit_sot_new = [
122
+ jnp .concatenate ([old_mit_sot [1 :], new_val [None , ...]], axis = 0 )
123
+ for old_mit_sot , new_val in zip (
124
+ inner_in_mit_sot ,
125
+ inner_out_mit_sot ,
126
+ )
127
+ ]
128
+
129
+ inner_out_sit_sot = inner_in_sit_sot_new = inner_scan_outs [
130
+ info .n_mit_mot
131
+ + info .n_mit_sot : info .n_mit_mot
132
+ + info .n_mit_sot
133
+ + info .n_sit_sot
134
+ ]
108
135
109
- inner_out_mit_sot = [
110
- update_mit_sot (mit_sot , new_val )
111
- for mit_sot , new_val in zip (inner_in_mit_sot , inner_scan_outs )
136
+ inner_out_nit_sot = inner_scan_outs [
137
+ info .n_mit_mot
138
+ + info .n_mit_sot
139
+ + info .n_sit_sot : info .n_mit_mot
140
+ + info .n_mit_sot
141
+ + info .n_sit_sot
142
+ + info .n_nit_sot :
112
143
]
113
144
114
- # This should contain all inner-output taps, non_seqs, and shared
115
- # terms
116
- if not inner_in_sit_sot :
117
- inner_out_sit_sot = []
118
- else :
119
- inner_out_sit_sot = inner_scan_outs
145
+ inner_in_shared_new = inner_in_shared
146
+ if info .n_shared_outs :
147
+ # Replace old shared inputs by new shared outputs
148
+ new_inner_out_shared = inner_scan_outs [
149
+ info .n_mit_mot + info .n_mit_sot + info .n_sit_sot + info .n_nit_sot :
150
+ ]
151
+ inner_in_shared_new [: info .n_shared_outs ] = new_inner_out_shared
152
+
120
153
new_carry = (
121
- inner_in_mit_mot ,
122
- inner_out_mit_sot ,
123
- inner_out_sit_sot ,
124
- inner_in_shared ,
154
+ [], # MIT-MOT
155
+ inner_in_mit_sot_new ,
156
+ inner_in_sit_sot_new ,
157
+ inner_in_shared_new ,
125
158
inner_in_non_seqs ,
126
159
)
127
160
128
- return new_carry
161
+ # Shared variables and non_seqs are not traced
162
+ new_scan = sum (
163
+ [
164
+ [], # MIT-MOT
165
+ inner_out_mit_sot ,
166
+ inner_out_sit_sot ,
167
+ inner_out_nit_sot ,
168
+ ],
169
+ [],
170
+ )
171
+
172
+ return new_carry , new_scan
129
173
130
174
def jax_inner_func (carry , x ):
131
175
inner_args = jax_args_to_inner_scan (op , carry , x )
132
- inner_scan_outs = list (jax_at_inner_func (* inner_args ))
133
- new_carry = inner_scan_outs_to_jax_outs (op , carry , inner_scan_outs )
134
- return new_carry , inner_scan_outs
135
-
136
- _ , scan_out = jax .lax .scan (jax_inner_func , init_carry , seqs , length = n_steps )
137
-
138
- # We need to prepend the initial values so that the JAX output will
139
- # match the raw `Scan` `Op` output and, thus, work with a downstream
140
- # `Subtensor` `Op` introduced by the `scan` helper function.
141
- def append_scan_out (scan_in_part , scan_out_part ):
142
- return jnp .concatenate ([scan_in_part [:- n_steps ], scan_out_part ], axis = 0 )
143
-
144
- if scan_args .outer_in_mit_sot :
145
- scan_out_final = [
146
- append_scan_out (init , out )
147
- for init , out in zip (scan_args .outer_in_mit_sot , scan_out )
148
- ]
149
- elif scan_args .outer_in_sit_sot :
150
- scan_out_final = [
151
- append_scan_out (init , out )
152
- for init , out in zip (scan_args .outer_in_sit_sot , scan_out )
153
- ]
176
+ inner_scan_outs = list (scan_inner_func (* inner_args ))
177
+ new_carry , new_scan_outs = inner_scan_outs_to_jax_outs (
178
+ op , carry , inner_scan_outs
179
+ )
180
+ return new_carry , new_scan_outs
181
+
182
+ last_state , scan_traces = jax .lax .scan (
183
+ jax_inner_func , init_carry , seqs , length = n_steps
184
+ )
185
+
186
def get_partial_traces(scan_traces):
    """Trim the stacked per-step traces to the sizes of the outer buffers.

    For MIT-SOT and SIT-SOT outputs the initial state is prepended to the
    trace, so that the JAX output matches the raw `Scan` `Op` output and,
    thus, works with a downstream `Subtensor` `Op` introduced by the
    `scan` helper function.

    Parameters
    ----------
    scan_traces
        Per-step outputs stacked by `jax.lax.scan`, ordered as MIT-SOT,
        SIT-SOT and then NIT-SOT traces.

    Returns
    -------
    list
        One array per trace, holding its last ``buffer_size`` entries
        along the first axis.
    """
    init_states = (
        mit_sot_in_slices
        + sit_sot_in_slices
        + [None] * len(scan_args.outer_in_nit_sot)
    )
    buffers = (
        scan_args.outer_in_mit_sot
        + scan_args.outer_in_sit_sot
        + scan_args.outer_in_nit_sot
    )

    partial_scan_traces = []
    for init_state, scan_trace, buffer in zip(init_states, scan_traces, buffers):
        full_scan_trace = jnp.atleast_1d(scan_trace)
        if init_state is None:
            # NIT-SOT: Buffer is just the number of entries that should be returned
            # NOTE(review): a buffer of 0 would select the whole trace
            # (`[-0:]`) -- confirm this cannot occur.
            buffer_size = buffer
        else:
            # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
            init_state = jnp.asarray(init_state)
            # A SIT-SOT initial state is a single time slice, so it lacks
            # the leading time axis of the trace. Restore the missing
            # leading axes; otherwise the concatenation fails for any
            # non-scalar state.
            missing_dims = full_scan_trace.ndim - init_state.ndim
            if missing_dims > 0:
                init_state = jnp.expand_dims(
                    init_state, tuple(range(missing_dims))
                )
            full_scan_trace = jnp.concatenate(
                [init_state, full_scan_trace], axis=0
            )
            buffer_size = buffer.shape[0]

        partial_scan_traces.append(full_scan_trace[-buffer_size:])

    return partial_scan_traces
218
+
219
def get_shared_outs(last_state):
    """Return the final values of the shared-variable outputs.

    Shared outputs are not traced, so only their last state, taken from
    the final carry of the scan, is returned.

    Parameters
    ----------
    last_state
        Final carry, a 5-tuple of
        ``(mit_mot, mit_sot, sit_sot, shared, non_seqs)``.

    Returns
    -------
    list
        The first ``info.n_shared_outs`` entries of the shared slot, or an
        empty list when the `Scan` has no shared outputs.
    """
    n_shared_outs = info.n_shared_outs
    if n_shared_outs:
        # Only the shared slot is used; the full unpacking keeps the check
        # that the carry has exactly five slots.
        _mit_mot, _mit_sot, _sit_sot, shared_state, _non_seqs = last_state

        # TODO: Check if a shared variable that is not an output shows up here or in non-seqs
        return list(shared_state[:n_shared_outs])

    return []
235
+
236
+ scan_outs_final = get_partial_traces (scan_traces ) + get_shared_outs (last_state )
154
237
155
- if len (scan_out_final ) == 1 :
156
- scan_out_final = scan_out_final [0 ]
157
- return scan_out_final
238
+ if len (scan_outs_final ) == 1 :
239
+ scan_outs_final = scan_outs_final [0 ]
240
+ return scan_outs_final
158
241
159
242
return scan
0 commit comments