Skip to content

Commit 727c032

Browse files
Bug-fixes and changes to statespace distributions
Remove tests related to the `add_exogenous` method. Add a dummy `MvNormalSVDRV` for forward JAX sampling with `method="svd"`. Dynamically generate the `LinearGaussianStateSpaceRV` signature from its inputs. Add a signature and a simple test for `SequenceMvNormal`.
1 parent 3249729 commit 727c032

File tree

2 files changed

+196
-81
lines changed

2 files changed

+196
-81
lines changed

pymc_experimental/statespace/filters/distributions.py

Lines changed: 127 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1+
import jax.random
12
import numpy as np
23
import pymc as pm
34
import pytensor
45
import pytensor.tensor as pt
56
from pymc import intX
67
from pymc.distributions.dist_math import check_parameters
78
from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
9+
from pymc.distributions.multivariate import MvNormal
810
from pymc.distributions.shape_utils import get_support_shape, get_support_shape_1d
911
from pymc.logprob.abstract import _logprob
1012
from pytensor.graph.basic import Node
13+
from pytensor.link.jax.dispatch.random import jax_sample_fn
14+
from pytensor.tensor.random.basic import MvNormalRV
1115

1216
floatX = pytensor.config.floatX
1317
COV_ZERO_TOL = 0
@@ -18,6 +22,64 @@
1822
)
1923

2024

25+
def make_signature(sequence_names):
    """Build the gufunc-style signature for a LinearGaussianStateSpaceRV.

    Parameters
    ----------
    sequence_names : list of str
        Names of state-space matrices that are time-varying; each one gets a
        leading time dimension prepended to its shape.

    Returns
    -------
    str
        Signature string mapping the nine state-space matrices plus an rng to
        an updated rng and a (time, states + observed) output.
    """
    # Single-letter dimension labels used throughout the signature.
    states, obs, exog = "s", "p", "r"
    time, state_and_obs = "t", "n"

    # Core (time-invariant) shape of every state-space matrix, listed in the
    # canonical input order of the RV op.
    matrix_to_shape = {
        "x0": (states,),
        "P0": (states, states),
        "c": (states,),
        "d": (obs,),
        "T": (states, states),
        "Z": (obs, states),
        "R": (states, exog),
        "H": (obs, obs),
        "Q": (exog, exog),
    }

    # Time-varying matrices gain a leading time axis. An unknown name raises
    # KeyError, exactly like a direct dict lookup.
    for name in sequence_names:
        matrix_to_shape[name] = (time,) + matrix_to_shape[name]

    input_part = ",".join(f"({','.join(dims)})" for dims in matrix_to_shape.values())

    return f"{input_part},[rng]->[rng],({time},{state_and_obs})"
51+
52+
53+
class MvNormalSVDRV(MvNormalRV):
    """Dummy multivariate-normal RV op.

    Behaves exactly like ``MvNormalRV``; the distinct class only exists so the
    JAX backend can dispatch sampling to ``jax.random.multivariate_normal``
    with ``method="svd"`` (see ``jax_sample_fn_mvnormal_svd``).
    """

    name = "multivariate_normal"
    signature = "(n),(n,n)->(n)"
    dtype = "floatX"
    _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
58+
59+
60+
class MvNormalSVD(MvNormal):
    """Dummy distribution intended to be rewritten into a JAX multivariate_normal with method="svd".

    A JAX MvNormal robust to low-rank covariance matrices.

    Under non-JAX backends this is an ordinary ``MvNormal``; only the JAX
    sampler registered for ``MvNormalSVDRV`` changes behavior.
    """

    # Use the dummy op so the JAX dispatcher can recognize this distribution.
    rv_op = MvNormalSVDRV()
67+
68+
69+
@jax_sample_fn.register(MvNormalSVDRV)
def jax_sample_fn_mvnormal_svd(op, node):
    """JAX sampler for ``MvNormalSVDRV`` using an SVD of the covariance.

    Registered with pytensor's JAX dispatcher so that forward sampling of a
    ``MvNormalSVD`` uses ``method="svd"``, which tolerates low-rank
    (singular) covariance matrices.
    """

    def sample_fn(rng, size, dtype, *parameters):
        # Split the carried key so successive calls draw independent samples.
        rng_key = rng["jax_state"]
        rng_key, sampling_key = jax.random.split(rng_key, 2)
        # parameters are presumably (mean, cov), forwarded positionally —
        # NOTE(review): confirm against MvNormalRV's input order.
        sample = jax.random.multivariate_normal(
            sampling_key, *parameters, shape=size, dtype=dtype, method="svd"
        )
        # Store the advanced key so the random state progresses between draws.
        rng["jax_state"] = rng_key
        return (rng, sample)

    return sample_fn
81+
82+
2183
class LinearGaussianStateSpaceRV(SymbolicRandomVariable):
2284
default_output = 1
2385
_print_name = ("LinearGuassianStateSpace", "\\operatorname{LinearGuassianStateSpace}")
@@ -28,6 +90,7 @@ def update(self, node: Node):
2890

2991
class _LinearGaussianStateSpace(Continuous):
3092
rv_op = LinearGaussianStateSpaceRV
93+
ndim_supp = 2
3194

3295
def __new__(
3396
cls,
@@ -91,25 +154,8 @@ def dist(
91154
[a0, P0, c, d, T, Z, R, H, Q, steps], mode=mode, sequence_names=sequence_names, **kwargs
92155
)
93156

94-
@classmethod
95-
def _get_k_states(cls, T):
96-
k_states = T.type.shape[0]
97-
if k_states is None:
98-
raise ValueError(lgss_shape_message)
99-
return k_states
100-
101-
@classmethod
102-
def _get_k_endog(cls, H):
103-
k_endog = H.type.shape[0]
104-
if k_endog is None:
105-
raise ValueError(lgss_shape_message)
106-
107-
return k_endog
108-
109157
@classmethod
110158
def rv_op(cls, a0, P0, c, d, T, Z, R, H, Q, steps, size=None, mode=None, sequence_names=None):
111-
if size is not None:
112-
batch_size = size
113159
if sequence_names is None:
114160
sequence_names = []
115161

@@ -125,77 +171,78 @@ def rv_op(cls, a0, P0, c, d, T, Z, R, H, Q, steps, size=None, mode=None, sequenc
125171
H_.name = "H"
126172
Q_.name = "Q"
127173

128-
n_seq = len(sequence_names)
129174
sequences = [
130175
x
131176
for x, name in zip([c_, d_, T_, Z_, R_, H_, Q_], ["c", "d", "T", "Z", "R", "H", "Q"])
132177
if name in sequence_names
133178
]
134179
non_sequences = [x for x in [c_, d_, T_, Z_, R_, H_, Q_] if x not in sequences]
135180

136-
steps_ = steps.type()
137181
rng = pytensor.shared(np.random.default_rng())
138182

139183
def sort_args(args):
140184
sorted_args = []
185+
186+
# Inside the scan, outputs_info variables get a time step appended to their name
187+
# e.g. x -> x[t]. Remove this so we can identify variables by name.
141188
arg_names = [x.name.replace("[t]", "") for x in args]
142189

190+
# c, d ,T, Z, R, H, Q is the "canonical" ordering
143191
for name in ["c", "d", "T", "Z", "R", "H", "Q"]:
144192
idx = arg_names.index(name)
145193
sorted_args.append(args[idx])
146194

147195
return sorted_args
148196

197+
n_seq = len(sequence_names)
198+
149199
def step_fn(*args):
150200
seqs, state, non_seqs = args[:n_seq], args[n_seq], args[n_seq + 1 :]
151201
non_seqs, rng = non_seqs[:-1], non_seqs[-1]
152202

153203
c, d, T, Z, R, H, Q = sort_args(seqs + non_seqs)
154-
155204
k = T.shape[0]
156205
a = state[:k]
157206

158-
middle_rng, a_innovation = pm.MvNormal.dist(mu=0, cov=Q, rng=rng).owner.outputs
159-
next_rng, y_innovation = pm.MvNormal.dist(mu=0, cov=H, rng=middle_rng).owner.outputs
207+
middle_rng, a_innovation = MvNormalSVD.dist(mu=0, cov=Q, rng=rng).owner.outputs
208+
next_rng, y_innovation = MvNormalSVD.dist(mu=0, cov=H, rng=middle_rng).owner.outputs
160209

161210
a_mu = c + T @ a
162-
a_next = pt.switch(pt.all(pt.le(Q, COV_ZERO_TOL)), a_mu, a_mu + R @ a_innovation)
211+
a_next = a_mu + R @ a_innovation
163212

164213
y_mu = d + Z @ a_next
165-
y_next = pt.switch(pt.all(pt.le(H, COV_ZERO_TOL)), y_mu, y_mu + y_innovation)
214+
y_next = y_mu + y_innovation
166215

167216
next_state = pt.concatenate([a_next, y_next], axis=0)
168217

169218
return next_state, {rng: next_rng}
170219

171-
init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng)
172220
Z_init = Z_ if Z_ in non_sequences else Z_[0]
173221
H_init = H_ if H_ in non_sequences else H_[0]
174222

175-
init_y_ = pt.switch(
176-
pt.all(pt.le(H_init, COV_ZERO_TOL)),
177-
Z_init @ init_x_,
178-
pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng),
179-
)
223+
init_x_ = MvNormalSVD.dist(a0_, P0_, rng=rng)
224+
init_y_ = MvNormalSVD.dist(Z_init @ init_x_, H_init, rng=rng)
225+
180226
init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
181227

182228
statespace, updates = pytensor.scan(
183229
step_fn,
184230
outputs_info=[init_dist_],
185231
sequences=None if len(sequences) == 0 else sequences,
186232
non_sequences=non_sequences + [rng],
187-
n_steps=steps_,
233+
n_steps=steps,
188234
mode=mode,
189235
strict=True,
190236
)
191237

192238
statespace_ = pt.concatenate([init_dist_[None], statespace], axis=0)
239+
statespace_ = pt.specify_shape(statespace_, (steps + 1, None))
193240

194241
(ss_rng,) = tuple(updates.values())
195242
linear_gaussian_ss_op = LinearGaussianStateSpaceRV(
196-
inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps_, rng],
243+
inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps, rng],
197244
outputs=[ss_rng, statespace_],
198-
ndim_supp=1,
245+
signature=make_signature(sequence_names),
199246
)
200247

201248
linear_gaussian_ss = linear_gaussian_ss_op(a0, P0, c, d, T, Z, R, H, Q, steps, rng)
@@ -221,10 +268,10 @@ def __new__(
221268
H,
222269
Q,
223270
*,
224-
steps=None,
225-
mode=None,
226-
sequence_names=None,
271+
steps,
227272
k_endog=None,
273+
sequence_names=None,
274+
mode=None,
228275
**kwargs,
229276
):
230277
dims = kwargs.pop("dims", None)
@@ -239,35 +286,29 @@ def __new__(
239286
latent_dims = [time_dim, state_dim]
240287
obs_dims = [time_dim, obs_dim]
241288

242-
matrices = (a0, P0, c, d, T, Z, R, H, Q)
289+
matrices = ()
290+
243291
latent_obs_combined = _LinearGaussianStateSpace(
244292
f"{name}_combined",
245-
*matrices,
293+
a0,
294+
P0,
295+
c,
296+
d,
297+
T,
298+
Z,
299+
R,
300+
H,
301+
Q,
246302
steps=steps,
247303
mode=mode,
248304
sequence_names=sequence_names,
249305
**kwargs,
250306
)
251-
k_states = T.type.shape[0]
252-
253-
if k_endog is None and k_states is None:
254-
raise ValueError("Could not infer number of observed states, explicitly pass k_endog.")
255-
if k_endog is not None and k_states is not None:
256-
total_shape = latent_obs_combined.type.shape[-1]
257-
inferred_endog = total_shape - k_states
258-
if inferred_endog != k_endog:
259-
raise ValueError(
260-
f"Inferred k_endog does not agree with provided value ({inferred_endog} != {k_endog}). "
261-
f"It is not necessary to provide k_endog when the value can be inferred."
262-
)
263-
latent_slice = slice(None, -k_endog)
264-
obs_slice = slice(-k_endog, None)
265-
elif k_endog is None:
266-
latent_slice = slice(None, k_states)
267-
obs_slice = slice(k_states, None)
268-
else:
269-
latent_slice = slice(None, -k_endog)
270-
obs_slice = slice(-k_endog, None)
307+
latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + 1, None))
308+
if k_endog is None:
309+
k_endog = cls._get_k_endog(H)
310+
latent_slice = slice(None, -k_endog)
311+
obs_slice = slice(-k_endog, None)
271312

272313
latent_states = latent_obs_combined[..., latent_slice]
273314
obs_states = latent_obs_combined[..., obs_slice]
@@ -289,10 +330,26 @@ def dist(cls, a0, P0, c, d, T, Z, R, H, Q, *, steps=None, **kwargs):
289330

290331
return latent_states, obs_states
291332

333+
@classmethod
def _get_k_states(cls, T):
    """Infer the number of latent states from the static shape of ``T``.

    Raises ``ValueError`` when the transition matrix has no static shape
    information to read.
    """
    n_states = T.type.shape[0]
    if n_states is None:
        raise ValueError(lgss_shape_message)
    return n_states

@classmethod
def _get_k_endog(cls, H):
    """Infer the number of observed states from the static shape of ``H``.

    Raises ``ValueError`` when the observation-noise covariance has no static
    shape information to read.
    """
    n_endog = H.type.shape[0]
    if n_endog is None:
        raise ValueError(lgss_shape_message)

    return n_endog
347+
292348

293349
class KalmanFilterRV(SymbolicRandomVariable):
    """Symbolic RV for a sequence of per-timestep multivariate-normal draws.

    Per the signature: means (t,s), covariances (t,s,s), and a logp (t),
    plus an rng, map to an updated rng and a (t,s) sequence of draws.
    """

    # Output 0 is the updated rng; output 1 (the default) is the draw sequence.
    default_output = 1
    _print_name = ("KalmanFilter", "\\operatorname{KalmanFilter}")
    signature = "(t,s),(t,s,s),(t),[rng]->[rng],(t,s)"

    def update(self, node: Node):
        # Map the rng input (last input) to the updated rng (first output)
        # so the random state advances across evaluations.
        return {node.inputs[-1]: node.outputs[0]}
@@ -325,48 +382,45 @@ def dist(cls, mus, covs, logp, support_shape=None, **kwargs):
325382
if support_shape is None:
326383
support_shape = pt.as_tensor_variable(())
327384

328-
steps = pm.intX(mus.shape[0])
329-
330-
return super().dist([mus, covs, logp, steps, support_shape], **kwargs)
385+
return super().dist([mus, covs, logp, support_shape], **kwargs)
331386

332387
@classmethod
333-
def rv_op(cls, mus, covs, logp, steps, support_shape, size=None):
388+
def rv_op(cls, mus, covs, logp, support_shape, size=None):
334389
if size is not None:
335390
batch_size = size
336391
else:
337392
batch_size = support_shape
338393

339-
# mus_, covs_ = mus.type(), covs.type()
340394
mus_, covs_, support_shape_ = mus.type(), covs.type(), support_shape.type()
341-
steps_ = steps.type()
342-
logp_ = logp.type()
343395

396+
logp_ = logp.type()
344397
rng = pytensor.shared(np.random.default_rng())
345398

346399
def step(mu, cov, rng):
347-
new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, size=batch_size).owner.outputs
400+
new_rng, mvn = MvNormalSVD.dist(mu=mu, cov=cov, rng=rng, size=batch_size).owner.outputs
348401
return mvn, {rng: new_rng}
349402

350403
mvn_seq, updates = pytensor.scan(
351-
step, sequences=[mus_, covs_], non_sequences=[rng], n_steps=steps_, strict=True
404+
step, sequences=[mus_, covs_], non_sequences=[rng], strict=True
352405
)
406+
mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape)
353407

354408
(seq_mvn_rng,) = tuple(updates.values())
355409

356410
mvn_seq_op = KalmanFilterRV(
357-
inputs=[mus_, covs_, logp_, steps_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
411+
inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
358412
)
359413

360-
mvn_seq = mvn_seq_op(mus, covs, logp, steps, rng)
414+
mvn_seq = mvn_seq_op(mus, covs, logp, rng)
415+
361416
return mvn_seq
362417

363418

364419
@_logprob.register(KalmanFilterRV)
def sequence_mvnormal_logp(op, values, mus, covs, logp, rng, **kwargs):
    """Return the pre-computed ``logp``, guarded by timestep-consistency checks.

    The value, means, and covariances must all share the same leading (time)
    dimension; otherwise ``check_parameters`` raises at evaluation time.
    """
    n_timesteps = mus.shape[0]
    return check_parameters(
        logp,
        pt.eq(values[0].shape[0], n_timesteps),
        pt.eq(covs.shape[0], n_timesteps),
        msg="Observed data and parameters must have the same number of timesteps (dimension 0)",
    )

0 commit comments

Comments
 (0)