Skip to content

Commit 0c806d5

Browse files
Fix shape of sigma_irregular in create_structural_model_and_equivalent_statsmodel
Add measurement noise to (un)conditional observed distributions SARIMAX tests don't depend on parameters being wrapped in `atleast_1d` inside `PyMCStateSpace.insert_random_variables()` Infer parameter shapes from symbolic placeholders rather than `self.param_info` Don't wrap all placeholders with `pt.atleast_1d` before applying `clone_replace` Add properties for exogenous data, insert exogenous data into model separately from parameters Use absolute path to test data Refactor `PyMCStateSpace` methods to no longer expect matrices in provided `idata` Don't add statespace matrices or outputs to PyMC graph by default.
1 parent 656b800 commit 0c806d5

File tree

12 files changed

+769
-364
lines changed

12 files changed

+769
-364
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 326 additions & 207 deletions
Large diffs are not rendered by default.

pymc_experimental/statespace/filters/distributions.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -290,16 +290,16 @@ def dist(cls, a0, P0, c, d, T, Z, R, H, Q, *, steps=None, **kwargs):
290290
return latent_states, obs_states
291291

292292

293-
class SequenceMvNormalRV(SymbolicRandomVariable):
293+
class KalmanFilterRV(SymbolicRandomVariable):
294294
default_output = 1
295-
_print_name = ("SequenceMvNormal", "\\operatorname{SequenceMvNormal}")
295+
_print_name = ("KalmanFilter", "\\operatorname{KalmanFilter}")
296296

297297
def update(self, node: Node):
298298
return {node.inputs[-1]: node.outputs[0]}
299299

300300

301301
class SequenceMvNormal(Continuous):
302-
rv_op = SequenceMvNormalRV
302+
rv_op = KalmanFilterRV
303303

304304
def __new__(cls, *args, **kwargs):
305305
support_shape = get_support_shape(
@@ -336,6 +336,7 @@ def rv_op(cls, mus, covs, logp, steps, support_shape, size=None):
336336
else:
337337
batch_size = support_shape
338338

339+
# mus_, covs_ = mus.type(), covs.type()
339340
mus_, covs_, support_shape_ = mus.type(), covs.type(), support_shape.type()
340341
steps_ = steps.type()
341342
logp_ = logp.type()
@@ -352,15 +353,15 @@ def step(mu, cov, rng):
352353

353354
(seq_mvn_rng,) = tuple(updates.values())
354355

355-
mvn_seq_op = SequenceMvNormalRV(
356+
mvn_seq_op = KalmanFilterRV(
356357
inputs=[mus_, covs_, logp_, steps_], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
357358
)
358359

359360
mvn_seq = mvn_seq_op(mus, covs, logp, steps)
360361
return mvn_seq
361362

362363

363-
@_logprob.register(SequenceMvNormalRV)
364+
@_logprob.register(KalmanFilterRV)
364365
def sequence_mvnormal_logp(op, values, mus, covs, logp, steps, rng, **kwargs):
365366
return check_parameters(
366367
logp,

pymc_experimental/statespace/filters/kalman_smoother.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ def build_graph(
100100
[smoothed_covariances[::-1], pt.expand_dims(P_last, axis=(0,))], axis=0
101101
)
102102

103+
smoothed_states.name = "smoothed_states"
104+
smoothed_covariances.name = "smoothed_covariances"
105+
103106
return smoothed_states, smoothed_covariances
104107

105108
def smoother_step(self, *args):

pymc_experimental/statespace/models/SARIMAX.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,11 @@ def param_info(self) -> Dict[str, Dict[str, Any]]:
269269
"constraints": "Positive Semi-definite",
270270
},
271271
"sigma_obs": {
272-
"shape": (self.k_endog,),
272+
"shape": None if self.k_endog == 1 else (self.k_endog,),
273273
"constraints": "Positive",
274274
},
275275
"sigma_state": {
276-
"shape": (self.k_posdef,),
276+
"shape": None if self.k_posdef == 1 else (self.k_posdef,),
277277
"constraints": "Positive",
278278
},
279279
"ar_params": {
@@ -330,8 +330,9 @@ def param_dims(self):
330330
"seasonal_ar_params": (SEASONAL_AR_PARAM_DIM,),
331331
"seasonal_ma_params": (SEASONAL_MA_PARAM_DIM,),
332332
}
333-
334-
if not self.measurement_error:
333+
if self.k_endog == 1:
334+
del coord_map["sigma_state"]
335+
if not self.measurement_error or self.k_endog == 1:
335336
del coord_map["sigma_obs"]
336337
if self.p == 0:
337338
del coord_map["ar_params"]
@@ -512,14 +513,14 @@ def make_symbolic_graph(self) -> None:
512513
# Set up the state covariance matrix
513514
state_cov_idx = ("state_cov",) + np.diag_indices(self.k_posdef)
514515
state_cov = self.make_and_register_variable(
515-
"sigma_state", shape=(self.k_posdef,), dtype=floatX
516+
"sigma_state", shape=() if self.k_posdef == 1 else (self.k_posdef,), dtype=floatX
516517
)
517518
self.ssm[state_cov_idx] = state_cov
518519

519520
if self.measurement_error:
520521
obs_cov_idx = ("obs_cov",) + np.diag_indices(self.k_endog)
521522
obs_cov = self.make_and_register_variable(
522-
"sigma_obs", shape=(self.k_endog,), dtype=floatX
523+
"sigma_obs", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
523524
)
524525
self.ssm[obs_cov_idx] = obs_cov
525526

0 commit comments

Comments
 (0)