@@ -1,13 +1,17 @@
+import jax.random
 import numpy as np
 import pymc as pm
 import pytensor
 import pytensor.tensor as pt
 from pymc import intX
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
+from pymc.distributions.multivariate import MvNormal
 from pymc.distributions.shape_utils import get_support_shape, get_support_shape_1d
 from pymc.logprob.abstract import _logprob
 from pytensor.graph.basic import Node
+from pytensor.link.jax.dispatch.random import jax_sample_fn
+from pytensor.tensor.random.basic import MvNormalRV

 floatX = pytensor.config.floatX
 COV_ZERO_TOL = 0
@@ -18,6 +22,64 @@
 )


+def make_signature(sequence_names):
+    states = "s"
+    obs = "p"
+    exog = "r"
+    time = "t"
+    state_and_obs = "n"
+
+    matrix_to_shape = {
+        "x0": (states,),
+        "P0": (states, states),
+        "c": (states,),
+        "d": (obs,),
+        "T": (states, states),
+        "Z": (obs, states),
+        "R": (states, exog),
+        "H": (obs, obs),
+        "Q": (exog, exog),
+    }
+
+    for matrix in sequence_names:
+        base_shape = matrix_to_shape[matrix]
+        matrix_to_shape[matrix] = (time,) + base_shape
+
+    signature = ",".join(["(" + ",".join(shapes) + ")" for shapes in matrix_to_shape.values()])
+
+    return f"{signature},[rng]->[rng],({time},{state_and_obs})"
+
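For concreteness, here is what `make_signature` produces — a sketch evaluated by hand from the function above, not part of the patch:

```python
# With no time-varying matrices, every input keeps its base shape:
make_signature([])
# '(s),(s,s),(s),(p),(s,s),(p,s),(s,r),(p,p),(r,r),[rng]->[rng],(t,n)'

# Naming a matrix in sequence_names prepends the time dimension "t" to its shape:
make_signature(["Q"])
# '(s),(s,s),(s),(p),(s,s),(p,s),(s,r),(p,p),(t,r,r),[rng]->[rng],(t,n)'
```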
+
+class MvNormalSVDRV(MvNormalRV):
+    name = "multivariate_normal"
+    signature = "(n),(n,n)->(n)"
+    dtype = "floatX"
+    _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
+
+
+class MvNormalSVD(MvNormal):
+    """Dummy distribution intended to be rewritten into a JAX multivariate_normal with method="svd".
+
+    A JAX MvNormal robust to low-rank covariance matrices
+    """
+
+    rv_op = MvNormalSVDRV()
+
+
+@jax_sample_fn.register(MvNormalSVDRV)
+def jax_sample_fn_mvnormal_svd(op, node):
+    def sample_fn(rng, size, dtype, *parameters):
+        rng_key = rng["jax_state"]
+        rng_key, sampling_key = jax.random.split(rng_key, 2)
+        sample = jax.random.multivariate_normal(
+            sampling_key, *parameters, shape=size, dtype=dtype, method="svd"
+        )
+        rng["jax_state"] = rng_key
+        return (rng, sample)
+
+    return sample_fn
+
+
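The point of this dispatch: when the graph is compiled with the JAX backend, draws from `MvNormalSVD` go through `jax.random.multivariate_normal(..., method="svd")`, which tolerates positive *semi*-definite covariances that a Cholesky-based sampler would reject. A minimal sketch of the intended behavior (my example, not part of the patch; assumes JAX is installed and that `pm.draw` forwards `mode` to the compiled function):

```python
import numpy as np
import pymc as pm

cov = np.array([[1.0, 1.0], [1.0, 1.0]])  # rank-1, singular covariance
x = MvNormalSVD.dist(mu=np.zeros(2), cov=cov)
draws = pm.draw(x, draws=100, mode="JAX")  # sampled via the SVD method, no Cholesky failure
```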
 class LinearGaussianStateSpaceRV(SymbolicRandomVariable):
     default_output = 1
     _print_name = ("LinearGuassianStateSpace", "\\operatorname{LinearGuassianStateSpace}")
@@ -28,6 +90,7 @@ def update(self, node: Node):

 class _LinearGaussianStateSpace(Continuous):
     rv_op = LinearGaussianStateSpaceRV
+    ndim_supp = 2

     def __new__(
         cls,
@@ -91,25 +154,8 @@ def dist(
             [a0, P0, c, d, T, Z, R, H, Q, steps], mode=mode, sequence_names=sequence_names, **kwargs
         )

-    @classmethod
-    def _get_k_states(cls, T):
-        k_states = T.type.shape[0]
-        if k_states is None:
-            raise ValueError(lgss_shape_message)
-        return k_states
-
-    @classmethod
-    def _get_k_endog(cls, H):
-        k_endog = H.type.shape[0]
-        if k_endog is None:
-            raise ValueError(lgss_shape_message)
-
-        return k_endog
-
     @classmethod
     def rv_op(cls, a0, P0, c, d, T, Z, R, H, Q, steps, size=None, mode=None, sequence_names=None):
-        if size is not None:
-            batch_size = size
         if sequence_names is None:
             sequence_names = []

@@ -125,77 +171,78 @@ def rv_op(cls, a0, P0, c, d, T, Z, R, H, Q, steps, size=None, mode=None, sequenc
         H_.name = "H"
         Q_.name = "Q"

-        n_seq = len(sequence_names)
         sequences = [
             x
             for x, name in zip([c_, d_, T_, Z_, R_, H_, Q_], ["c", "d", "T", "Z", "R", "H", "Q"])
             if name in sequence_names
         ]
         non_sequences = [x for x in [c_, d_, T_, Z_, R_, H_, Q_] if x not in sequences]

-        steps_ = steps.type()
         rng = pytensor.shared(np.random.default_rng())

         def sort_args(args):
             sorted_args = []
+
+            # Inside the scan, outputs_info variables get a time step appended to their name,
+            # e.g. x -> x[t]. Remove this so we can identify variables by name.
             arg_names = [x.name.replace("[t]", "") for x in args]

+            # c, d, T, Z, R, H, Q is the "canonical" ordering
             for name in ["c", "d", "T", "Z", "R", "H", "Q"]:
                 idx = arg_names.index(name)
                 sorted_args.append(args[idx])

             return sorted_args

+        n_seq = len(sequence_names)
+
         def step_fn(*args):
             seqs, state, non_seqs = args[:n_seq], args[n_seq], args[n_seq + 1 :]
             non_seqs, rng = non_seqs[:-1], non_seqs[-1]

             c, d, T, Z, R, H, Q = sort_args(seqs + non_seqs)
-
             k = T.shape[0]
             a = state[:k]

-            middle_rng, a_innovation = pm.MvNormal.dist(mu=0, cov=Q, rng=rng).owner.outputs
-            next_rng, y_innovation = pm.MvNormal.dist(mu=0, cov=H, rng=middle_rng).owner.outputs
+            middle_rng, a_innovation = MvNormalSVD.dist(mu=0, cov=Q, rng=rng).owner.outputs
+            next_rng, y_innovation = MvNormalSVD.dist(mu=0, cov=H, rng=middle_rng).owner.outputs

             a_mu = c + T @ a
-            a_next = pt.switch(pt.all(pt.le(Q, COV_ZERO_TOL)), a_mu, a_mu + R @ a_innovation)
+            a_next = a_mu + R @ a_innovation

             y_mu = d + Z @ a_next
-            y_next = pt.switch(pt.all(pt.le(H, COV_ZERO_TOL)), y_mu, y_mu + y_innovation)
+            y_next = y_mu + y_innovation

             next_state = pt.concatenate([a_next, y_next], axis=0)

             return next_state, {rng: next_rng}

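Written out, `step_fn` is the standard linear-Gaussian state-space recursion (subscript a matrix by $t$ when it appears in `sequence_names`):

$$
\begin{aligned}
a_{t+1} &= c + T a_t + R \varepsilon_t, & \varepsilon_t &\sim N(0,\, Q) \\
y_{t+1} &= d + Z a_{t+1} + \eta_t, & \eta_t &\sim N(0,\, H)
\end{aligned}
$$

Note that the removed `pt.switch` guards against all-zero $Q$ or $H$ become unnecessary with the SVD sampler: a zero covariance simply yields draws equal to the mean.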
-        init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng)
         Z_init = Z_ if Z_ in non_sequences else Z_[0]
         H_init = H_ if H_ in non_sequences else H_[0]

-        init_y_ = pt.switch(
-            pt.all(pt.le(H_init, COV_ZERO_TOL)),
-            Z_init @ init_x_,
-            pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng),
-        )
+        init_x_ = MvNormalSVD.dist(a0_, P0_, rng=rng)
+        init_y_ = MvNormalSVD.dist(Z_init @ init_x_, H_init, rng=rng)
+
         init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)

         statespace, updates = pytensor.scan(
             step_fn,
             outputs_info=[init_dist_],
             sequences=None if len(sequences) == 0 else sequences,
             non_sequences=non_sequences + [rng],
-            n_steps=steps_,
+            n_steps=steps,
             mode=mode,
             strict=True,
         )

         statespace_ = pt.concatenate([init_dist_[None], statespace], axis=0)
+        statespace_ = pt.specify_shape(statespace_, (steps + 1, None))

         (ss_rng,) = tuple(updates.values())
         linear_gaussian_ss_op = LinearGaussianStateSpaceRV(
-            inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps_, rng],
+            inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps, rng],
             outputs=[ss_rng, statespace_],
-            ndim_supp=1,
+            signature=make_signature(sequence_names),
         )

         linear_gaussian_ss = linear_gaussian_ss_op(a0, P0, c, d, T, Z, R, H, Q, steps, rng)
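The new `pt.specify_shape` calls (here and in `KalmanFilter.rv_op` below) re-attach static shape information that `pytensor.scan` outputs lack. A minimal sketch of the effect, assuming `steps` is a compile-time constant:

```python
import pytensor.tensor as pt

steps = 100
statespace = pt.dmatrix("statespace")             # static shape unknown: (None, None)
pinned = pt.specify_shape(statespace, (steps + 1, None))
print(pinned.type.shape)                          # (101, None): the time length is now static
```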
@@ -221,10 +268,10 @@ def __new__(
         H,
         Q,
         *,
-        steps=None,
-        mode=None,
-        sequence_names=None,
+        steps,
         k_endog=None,
+        sequence_names=None,
+        mode=None,
         **kwargs,
     ):
         dims = kwargs.pop("dims", None)
@@ -239,35 +286,29 @@ def __new__(
         latent_dims = [time_dim, state_dim]
         obs_dims = [time_dim, obs_dim]

-        matrices = (a0, P0, c, d, T, Z, R, H, Q)
+        matrices = ()
+
         latent_obs_combined = _LinearGaussianStateSpace(
             f"{name}_combined",
-            *matrices,
+            a0,
+            P0,
+            c,
+            d,
+            T,
+            Z,
+            R,
+            H,
+            Q,
             steps=steps,
             mode=mode,
             sequence_names=sequence_names,
             **kwargs,
         )
-        k_states = T.type.shape[0]
-
-        if k_endog is None and k_states is None:
-            raise ValueError("Could not infer number of observed states, explicitly pass k_endog.")
-        if k_endog is not None and k_states is not None:
-            total_shape = latent_obs_combined.type.shape[-1]
-            inferred_endog = total_shape - k_states
-            if inferred_endog != k_endog:
-                raise ValueError(
-                    f"Inferred k_endog does not agree with provided value ({inferred_endog} != {k_endog}). "
-                    f"It is not necessary to provide k_endog when the value can be inferred."
-                )
-            latent_slice = slice(None, -k_endog)
-            obs_slice = slice(-k_endog, None)
-        elif k_endog is None:
-            latent_slice = slice(None, k_states)
-            obs_slice = slice(k_states, None)
-        else:
-            latent_slice = slice(None, -k_endog)
-            obs_slice = slice(-k_endog, None)
+        latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + 1, None))
+        if k_endog is None:
+            k_endog = cls._get_k_endog(H)
+        latent_slice = slice(None, -k_endog)
+        obs_slice = slice(-k_endog, None)

         latent_states = latent_obs_combined[..., latent_slice]
         obs_states = latent_obs_combined[..., obs_slice]
@@ -289,10 +330,26 @@ def dist(cls, a0, P0, c, d, T, Z, R, H, Q, *, steps=None, **kwargs):

         return latent_states, obs_states

+    @classmethod
+    def _get_k_states(cls, T):
+        k_states = T.type.shape[0]
+        if k_states is None:
+            raise ValueError(lgss_shape_message)
+        return k_states
+
+    @classmethod
+    def _get_k_endog(cls, H):
+        k_endog = H.type.shape[0]
+        if k_endog is None:
+            raise ValueError(lgss_shape_message)
+
+        return k_endog
+

 class KalmanFilterRV(SymbolicRandomVariable):
     default_output = 1
     _print_name = ("KalmanFilter", "\\operatorname{KalmanFilter}")
+    signature = "(t,s),(t,s,s),(t),[rng]->[rng],(t,s)"

     def update(self, node: Node):
         return {node.inputs[-1]: node.outputs[0]}
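Reading the new class-level `signature` term by term (a gloss, not part of the patch):

```python
# "(t,s),(t,s,s),(t),[rng]->[rng],(t,s)"
#   mus:    (t, s)     one mean vector per timestep
#   covs:   (t, s, s)  one covariance matrix per timestep
#   logp:   (t)        per-timestep log-probabilities, passed through to the logp graph
#   [rng] -> [rng]     the shared RNG is an input; its update is the op's first output
#   output: (t, s)     one multivariate-normal draw per timestep
```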
@@ -325,48 +382,45 @@ def dist(cls, mus, covs, logp, support_shape=None, **kwargs):
         if support_shape is None:
             support_shape = pt.as_tensor_variable(())

-        steps = pm.intX(mus.shape[0])
-
-        return super().dist([mus, covs, logp, steps, support_shape], **kwargs)
+        return super().dist([mus, covs, logp, support_shape], **kwargs)

     @classmethod
-    def rv_op(cls, mus, covs, logp, steps, support_shape, size=None):
+    def rv_op(cls, mus, covs, logp, support_shape, size=None):
         if size is not None:
             batch_size = size
         else:
             batch_size = support_shape

-        # mus_, covs_ = mus.type(), covs.type()
         mus_, covs_, support_shape_ = mus.type(), covs.type(), support_shape.type()
-        steps_ = steps.type()
-        logp_ = logp.type()

+        logp_ = logp.type()
         rng = pytensor.shared(np.random.default_rng())

         def step(mu, cov, rng):
-            new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, size=batch_size).owner.outputs
+            new_rng, mvn = MvNormalSVD.dist(mu=mu, cov=cov, rng=rng, size=batch_size).owner.outputs
             return mvn, {rng: new_rng}

         mvn_seq, updates = pytensor.scan(
-            step, sequences=[mus_, covs_], non_sequences=[rng], n_steps=steps_, strict=True
+            step, sequences=[mus_, covs_], non_sequences=[rng], strict=True
         )
+        mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape)

         (seq_mvn_rng,) = tuple(updates.values())

         mvn_seq_op = KalmanFilterRV(
-            inputs=[mus_, covs_, logp_, steps_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
+            inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
         )

-        mvn_seq = mvn_seq_op(mus, covs, logp, steps, rng)
+        mvn_seq = mvn_seq_op(mus, covs, logp, rng)
+
         return mvn_seq


 @_logprob.register(KalmanFilterRV)
-def sequence_mvnormal_logp(op, values, mus, covs, logp, steps, rng, **kwargs):
+def sequence_mvnormal_logp(op, values, mus, covs, logp, rng, **kwargs):
     return check_parameters(
         logp,
-        pt.eq(values[0].shape[0], steps),
-        pt.eq(mus.shape[0], steps),
-        pt.eq(covs.shape[0], steps),
+        pt.eq(values[0].shape[0], mus.shape[0]),
+        pt.eq(covs.shape[0], mus.shape[0]),
         msg="Observed data and parameters must have the same number of timesteps (dimension 0)",
     )
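With the explicit `steps` input gone, `check_parameters` now compares the timestep dimensions of the inputs against each other. It leaves the logp graph intact but attaches runtime assertions that raise instead of silently broadcasting. A standalone sketch of that behavior (my example, using only the public helper):

```python
import pytensor.tensor as pt
from pymc.distributions.dist_math import check_parameters

x = pt.scalar("x")
safe = check_parameters(x, pt.gt(x, 0), msg="x must be positive")

safe.eval({x: 2.0})   # 2.0
safe.eval({x: -1.0})  # raises ParameterValueError: x must be positive
```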