Skip to content

Commit 6ac1d0d

Browse files
Fix shape of sigma_irregular in create_structural_model_and_equivalent_statsmodel
- Add measurement noise to (un)conditional observed distributions.
- SARIMAX tests don't depend on parameters being wrapped in `atleast_1d` inside `PyMCStateSpace.insert_random_variables()`.
- Infer parameter shapes from symbolic placeholders rather than `self.param_info`.
- Don't wrap all placeholders with `pt.atleast_1d` before applying `clone_replace`.
- Add properties for exogenous data; insert exogenous data into the model separately from parameters.
- Use absolute path to test data.
- Refactor `PyMCStateSpace` methods to no longer expect matrices in the provided `idata`.
- Don't add statespace matrices or outputs to the PyMC graph by default.
1 parent 1c40824 commit 6ac1d0d

File tree

12 files changed

+488
-291
lines changed

12 files changed

+488
-291
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 326 additions & 207 deletions
Large diffs are not rendered by default.

pymc_experimental/statespace/filters/distributions.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -290,16 +290,16 @@ def dist(cls, a0, P0, c, d, T, Z, R, H, Q, *, steps=None, **kwargs):
290290
return latent_states, obs_states
291291

292292

293-
class SequenceMvNormalRV(SymbolicRandomVariable):
293+
class KalmanFilterRV(SymbolicRandomVariable):
294294
default_output = 1
295-
_print_name = ("SequenceMvNormal", "\\operatorname{SequenceMvNormal}")
295+
_print_name = ("KalmanFilter", "\\operatorname{KalmanFilter}")
296296

297297
# Returns a one-entry mapping from the node's last input to its first output.
# NOTE(review): presumably the RNG input and its updated state — TODO confirm against rv_op's inputs/outputs.
def update(self, node: Node):
298298
return {node.inputs[-1]: node.outputs[0]}
299299

300300

301301
class SequenceMvNormal(Continuous):
302-
rv_op = SequenceMvNormalRV
302+
rv_op = KalmanFilterRV
303303

304304
def __new__(cls, *args, **kwargs):
305305
support_shape = get_support_shape(
@@ -336,6 +336,7 @@ def rv_op(cls, mus, covs, logp, steps, support_shape, size=None):
336336
else:
337337
batch_size = support_shape
338338

339+
# mus_, covs_ = mus.type(), covs.type()
339340
mus_, covs_, support_shape_ = mus.type(), covs.type(), support_shape.type()
340341
steps_ = steps.type()
341342
logp_ = logp.type()
@@ -352,15 +353,15 @@ def step(mu, cov, rng):
352353

353354
(seq_mvn_rng,) = tuple(updates.values())
354355

355-
mvn_seq_op = SequenceMvNormalRV(
356+
mvn_seq_op = KalmanFilterRV(
356357
inputs=[mus_, covs_, logp_, steps_], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2
357358
)
358359

359360
mvn_seq = mvn_seq_op(mus, covs, logp, steps)
360361
return mvn_seq
361362

362363

363-
@_logprob.register(SequenceMvNormalRV)
364+
@_logprob.register(KalmanFilterRV)
364365
def sequence_mvnormal_logp(op, values, mus, covs, logp, steps, rng, **kwargs):
365366
return check_parameters(
366367
logp,

pymc_experimental/statespace/filters/kalman_smoother.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ def build_graph(
100100
[smoothed_covariances[::-1], pt.expand_dims(P_last, axis=(0,))], axis=0
101101
)
102102

103+
smoothed_states.name = "smoothed_states"
104+
smoothed_covariances.name = "smoothed_covariances"
105+
103106
return smoothed_states, smoothed_covariances
104107

105108
def smoother_step(self, *args):

pymc_experimental/statespace/models/SARIMAX.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,11 @@ def param_info(self) -> Dict[str, Dict[str, Any]]:
269269
"constraints": "Positive Semi-definite",
270270
},
271271
"sigma_obs": {
272-
"shape": (self.k_endog,),
272+
"shape": None if self.k_endog == 1 else (self.k_endog,),
273273
"constraints": "Positive",
274274
},
275275
"sigma_state": {
276-
"shape": (self.k_posdef,),
276+
"shape": None if self.k_posdef == 1 else (self.k_posdef,),
277277
"constraints": "Positive",
278278
},
279279
"ar_params": {
@@ -330,8 +330,9 @@ def param_dims(self):
330330
"seasonal_ar_params": (SEASONAL_AR_PARAM_DIM,),
331331
"seasonal_ma_params": (SEASONAL_MA_PARAM_DIM,),
332332
}
333-
334-
if not self.measurement_error:
333+
if self.k_endog == 1:
334+
del coord_map["sigma_state"]
335+
if not self.measurement_error or self.k_endog == 1:
335336
del coord_map["sigma_obs"]
336337
if self.p == 0:
337338
del coord_map["ar_params"]
@@ -512,14 +513,14 @@ def make_symbolic_graph(self) -> None:
512513
# Set up the state covariance matrix
513514
state_cov_idx = ("state_cov",) + np.diag_indices(self.k_posdef)
514515
state_cov = self.make_and_register_variable(
515-
"sigma_state", shape=(self.k_posdef,), dtype=floatX
516+
"sigma_state", shape=() if self.k_posdef == 1 else (self.k_posdef,), dtype=floatX
516517
)
517518
self.ssm[state_cov_idx] = state_cov**2
518519

519520
if self.measurement_error:
520521
obs_cov_idx = ("obs_cov",) + np.diag_indices(self.k_endog)
521522
obs_cov = self.make_and_register_variable(
522-
"sigma_obs", shape=(self.k_endog,), dtype=floatX
523+
"sigma_obs", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
523524
)
524525
self.ssm[obs_cov_idx] = obs_cov**2
525526

pymc_experimental/statespace/models/structural.py

Lines changed: 88 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,18 @@ def __init__(
8181
self,
8282
ssm: PytensorRepresentation,
8383
state_names,
84+
data_names,
8485
shock_names,
8586
param_names,
8687
exog_names,
8788
param_dims,
8889
coords,
8990
param_info,
91+
data_info,
9092
component_info,
9193
measurement_error,
9294
name_to_variable,
95+
name_to_data,
9396
name=None,
9497
verbose=True,
9598
filter_type: str = "standard",
@@ -104,6 +107,7 @@ def __init__(
104107
param_names, param_dims, param_info, k_states
105108
)
106109
self._state_names = state_names
110+
self._data_names = data_names
107111
self._shock_names = shock_names
108112
self._param_names = param_names
109113
self._param_dims = param_dims
@@ -113,6 +117,7 @@ def __init__(
113117

114118
self._coords = coords
115119
self._param_info = param_info
120+
self._data_info = data_info
116121
self.measurement_error = measurement_error
117122

118123
super().__init__(
@@ -126,7 +131,10 @@ def __init__(
126131

127132
self.ssm = ssm
128133
self._component_info = component_info
134+
129135
self._name_to_variable = name_to_variable
136+
self._name_to_data = name_to_data
137+
130138
self._exog_names = exog_names
131139
self._needs_exog_data = len(exog_names) > 0
132140

@@ -149,6 +157,10 @@ def _add_inital_state_cov_to_properties(param_names, param_dims, param_info, k_s
149157
def param_names(self):
150158
return self._param_names
151159

160+
# Read-only accessor: returns the list stored in self._data_names.
# NOTE(review): presumably the names of data placeholders (e.g. exogenous regressors) — TODO confirm.
@property
161+
def data_names(self) -> List[str]:
162+
return self._data_names
163+
152164
@property
153165
def state_names(self):
154166
return self._state_names
@@ -173,6 +185,10 @@ def coords(self) -> Dict[str, Sequence]:
173185
def param_info(self) -> Dict[str, Dict[str, Any]]:
174186
return self._param_info
175187

188+
# Read-only accessor: returns the dict stored in self._data_info.
# NOTE(review): entries appear to map a data name to metadata such as "shape" and "dims" — TODO confirm.
@property
189+
def data_info(self) -> dict[str, dict[str, Any]]:
190+
return self._data_info
191+
176192
def make_symbolic_graph(self) -> None:
177193
"""
178194
Assign placeholder pytensor variables among statespace matrices in positions where PyMC variables will go.
@@ -338,6 +354,7 @@ def __init__(
338354
k_states,
339355
k_posdef,
340356
state_names=None,
357+
data_names=None,
341358
shock_names=None,
342359
param_names=None,
343360
exog_names=None,
@@ -354,14 +371,18 @@ def __init__(
354371
self.measurement_error = measurement_error
355372

356373
self.state_names = state_names if state_names is not None else []
374+
self.data_names = data_names if data_names is not None else []
357375
self.shock_names = shock_names if shock_names is not None else []
358376
self.param_names = param_names if param_names is not None else []
359377
self.exog_names = exog_names if exog_names is not None else []
360378

361379
self.needs_exog_data = len(self.exog_names) > 0
362380
self.coords = {}
363381
self.param_dims = {}
382+
364383
self.param_info = {}
384+
self.data_info = {}
385+
365386
self.param_counts = {}
366387

367388
if representation is None:
@@ -370,6 +391,7 @@ def __init__(
370391
self.ssm = representation
371392

372393
self._name_to_variable = {}
394+
self._name_to_data = {}
373395

374396
if not component_from_sum:
375397
self.populate_component_properties()
@@ -429,6 +451,43 @@ def make_and_register_variable(self, name, shape, dtype=floatX) -> Variable:
429451
self._name_to_variable[name] = placeholder
430452
return placeholder
431453

454+
def make_and_register_data(self, name, shape, dtype=floatX) -> Variable:
    r"""
    Create a pytensor symbolic placeholder for a data input and register it in ``_name_to_data``.

    Parameters
    ----------
    name : str
        The name of the placeholder data. Must be one of the names listed in ``data_names``.
    shape : int or tuple of int
        Shape of the data placeholder.
    dtype : str, default pytensor.config.floatX
        dtype of the data placeholder.

    Returns
    -------
    placeholder : Variable
        The newly created ``pt.tensor`` placeholder. It is also stored under ``name`` in ``_name_to_data``.

    Raises
    ------
    ValueError
        If ``name`` is not present in the ``data_names`` property, or if a placeholder called ``name`` has
        already been registered.

    Notes
    -----
    See docstring for make_and_register_variable for more details. This function is similar, but handles data
    inputs instead of model parameters.
    """
    # Data placeholders are validated against data_names (not param_names), so the message must say so.
    if name not in self.data_names:
        raise ValueError(
            f"{name} is not a model data variable. All data placeholders should correspond to names "
            f"in the data_names property."
        )

    # Refuse to silently overwrite an existing placeholder; report the shape already registered.
    if name in self._name_to_data:
        raise ValueError(
            f"{name} is already a registered placeholder variable with shape "
            f"{self._name_to_data[name].type.shape}"
        )

    placeholder = pt.tensor(name, shape=shape, dtype=dtype)
    self._name_to_data[name] = placeholder
    return placeholder
490+
432491
def make_symbolic_graph(self) -> None:
433492
raise NotImplementedError
434493

@@ -481,7 +540,6 @@ def make_slice(name, x, o_x):
481540
transition.name = T.name
482541

483542
design = pt.concatenate(conform_time_varying_and_time_invariant_matrices(Z, o_Z), axis=-1)
484-
485543
design.name = Z.name
486544

487545
selection = block_diagonal([R, o_R])
@@ -542,14 +600,18 @@ def _make_combined_name(self):
542600

543601
def __add__(self, other):
544602
state_names = self._combine_property(other, "state_names")
603+
data_names = self._combine_property(other, "data_names")
545604
param_names = self._combine_property(other, "param_names")
546605
shock_names = self._combine_property(other, "shock_names")
547606
param_info = self._combine_property(other, "param_info")
607+
data_info = self._combine_property(other, "data_info")
548608
param_dims = self._combine_property(other, "param_dims")
549609
coords = self._combine_property(other, "coords")
550610
exog_names = self._combine_property(other, "exog_names")
551611

552612
_name_to_variable = self._combine_property(other, "_name_to_variable")
613+
_name_to_data = self._combine_property(other, "_name_to_data")
614+
553615
measurement_error = any([self.measurement_error, other.measurement_error])
554616

555617
k_states, k_posdef, k_endog = self._get_combined_shapes(other)
@@ -567,30 +629,22 @@ def __add__(self, other):
567629
new_comp._component_info = self._combine_component_info(other)
568630
new_comp.name = new_comp._make_combined_name()
569631

570-
property_names = [
571-
"state_names",
572-
"param_names",
573-
"shock_names",
574-
"state_dims",
575-
"coords",
576-
"param_dims",
577-
"param_info",
578-
"exog_names",
579-
"_name_to_variable",
580-
]
581-
property_values = [
582-
state_names,
583-
param_names,
584-
shock_names,
585-
param_dims,
586-
coords,
587-
param_dims,
588-
param_info,
589-
exog_names,
590-
_name_to_variable,
632+
names_and_props = [
633+
("state_names", state_names),
634+
("data_names", data_names),
635+
("param_names", param_names),
636+
("shock_names", shock_names),
637+
("param_dims", param_dims),
638+
("coords", coords),
639+
("param_dims", param_dims),
640+
("param_info", param_info),
641+
("data_info", data_info),
642+
("exog_names", exog_names),
643+
("_name_to_variable", _name_to_variable),
644+
("_name_to_data", _name_to_data),
591645
]
592646

593-
for prop, value in zip(property_names, property_values):
647+
for prop, value in names_and_props:
594648
setattr(new_comp, prop, value)
595649

596650
return new_comp
@@ -622,15 +676,18 @@ def build(self, name=None, filter_type="standard", verbose=True):
622676
self.ssm,
623677
name=name,
624678
state_names=self.state_names,
679+
data_names=self.data_names,
625680
shock_names=self.shock_names,
626681
param_names=self.param_names,
627682
param_dims=self.param_dims,
628683
coords=self.coords,
629684
param_info=self.param_info,
685+
data_info=self.data_info,
630686
component_info=self._component_info,
631687
measurement_error=self.measurement_error,
632688
exog_names=self.exog_names,
633689
name_to_variable=self._name_to_variable,
690+
name_to_data=self._name_to_data,
634691
filter_type=filter_type,
635692
verbose=verbose,
636693
)
@@ -881,7 +938,8 @@ def populate_component_properties(self):
881938
}
882939

883940
def make_symbolic_graph(self) -> None:
884-
error_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=(self.k_endog,))
941+
sigma_shape = () if self.k_endog == 1 else (self.k_endog,)
942+
error_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=sigma_shape)
885943
diag_idx = np.diag_indices(self.k_endog)
886944
idx = np.s_["obs_cov", diag_idx[0], diag_idx[1]]
887945
self.ssm[idx] = error_sigma**2
@@ -1541,7 +1599,7 @@ def _handle_input_data(self, k_exog: int, state_names: Optional[List[str]], name
15411599

15421600
def make_symbolic_graph(self) -> None:
15431601
betas = self.make_and_register_variable(f"beta_{self.name}", shape=(self.k_states,))
1544-
regression_data = self.make_and_register_variable(
1602+
regression_data = self.make_and_register_data(
15451603
f"data_{self.name}", shape=(None, self.k_states)
15461604
)
15471605

@@ -1560,17 +1618,19 @@ def make_symbolic_graph(self) -> None:
15601618
def populate_component_properties(self) -> None:
15611619
self.shock_names = self.state_names
15621620

1563-
self.param_names = [f"beta_{self.name}", f"data_{self.name}"]
1621+
self.param_names = [f"beta_{self.name}"]
1622+
self.data_names = [f"data_{self.name}"]
15641623
self.param_dims = {
15651624
f"beta_{self.name}": ("exog_state",),
1566-
f"data_{self.name}": (TIME_DIM, "exog_state"),
15671625
}
15681626

15691627
self.param_info = {
15701628
f"beta_{self.name}": {"shape": (1,), "constraints": None, "dims": ("exog_state",)},
1629+
}
1630+
1631+
self.data_info = {
15711632
f"data_{self.name}": {
15721633
"shape": (None, self.k_states),
1573-
"constraints": None,
15741634
"dims": (TIME_DIM, "exog_state"),
15751635
},
15761636
}

pymc_experimental/tests/statespace/test_SARIMAX.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def pymc_mod(arima_mod):
180180
ar_params = pm.Normal("ar_params", sigma=0.1, dims=["ar_lag"])
181181
ma_params = pm.Normal("ma_params", sigma=1, dims=["ma_lag"])
182182
sigma_state = pm.Exponential("sigma_state", 0.5)
183-
arima_mod.build_statespace_graph(data=data)
183+
arima_mod.build_statespace_graph(data=data, save_kalman_filter_outputs_in_idata=True)
184184

185185
return pymc_mod
186186

@@ -210,7 +210,8 @@ def pymc_mod_interp(arima_mod_interp):
210210
ma_params = pm.Normal("ma_params", sigma=1, dims=["ma_lag"])
211211
sigma_state = pm.Exponential("sigma_state", 0.5)
212212
sigma_obs = pm.Exponential("sigma_obs", 0.1)
213-
arima_mod_interp.build_statespace_graph(data=data)
213+
214+
arima_mod_interp.build_statespace_graph(data=data, save_kalman_filter_outputs_in_idata=True)
214215

215216
return pymc_mod
216217

0 commit comments

Comments
 (0)