Skip to content

Commit dede2b3

Browse files
Add helper function to sample prior/posterior statespace matrices
1 parent 8950d56 commit dede2b3

File tree

1 file changed

+80
-11
lines changed

1 file changed

+80
-11
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class PyMCStateSpace:
145145
\end{align}
146146
147147
With the remaining statespace matrices as zero matrices of the appropriate sizes. The model has two states,
148-
two shocks, and one observed state. Knowing all this, a very simple local level model can be implemented as
148+
two shocks, and one observed state. Knowing all this, a very simple local level model can be implemented as
149149
follows:
150150
151151
.. code:: python
@@ -167,17 +167,18 @@ def __init__():
167167
def param_names(self):
168168
return ['x0', 'P0', 'sigma_nu', 'sigma_eta']
169169
170-
def update(self, theta, mode=None):
171-
# Since the param_names are ['x0', 'P0', 'sigma_nu', 'sigma_eta'], theta will come in as
172-
# [x0.ravel(), P0.ravel(), sigma_nu, sigma_eta]
173-
# It will have length 2 + 4 + 1 + 1 = 8
170+
def make_symbolic_graph(self):
171+
# Declare symbolic variables that represent parameters of the model
172+
# In this case, we have 4: x0 (initial state), P0 (initial state covariance), sigma_nu, and sigma_eta
174173
175-
x0 = theta[:2]
176-
P0 = theta[2:6].reshape(2,2)
177-
sigma_nu = theta[6]
178-
sigma_eta = theta[7]
174+
x0 = self.make_and_register_variable('x0', shape=(2,))
175+
P0 = self.make_and_register_variable('P0', shape=(2,2))
176+
sigma_mu = self.make_and_register_variable('sigma_nu')
177+
sigma_eta = self.make_and_register_variable('sigma_eta')
178+
179+
# Next, use these symbolic variables to build the statespace matrices by assigning each parameter
180+
# to its correct location in the correct matrix
179181
180-
# Assign parameters to their correct locations
181182
self.ssm['initial_state', :] = x0
182183
self.ssm['initial_state_cov', :, :] = P0
183184
self.ssm['state_cov', 0, 0] = sigma_nu
@@ -443,7 +444,9 @@ def add_default_priors(self) -> None:
443444
"""
444445
raise NotImplementedError("The add_default_priors property has not been implemented!")
445446

446-
def make_and_register_variable(self, name, shape, dtype=floatX) -> Variable:
447+
def make_and_register_variable(
448+
self, name, shape: int | tuple[int] | None = None, dtype=floatX
449+
) -> Variable:
447450
"""
448451
Helper function to create a pytensor symbolic variable and register it in the _name_to_variable dictionary
449452
@@ -1221,6 +1224,12 @@ def _sample_unconditional(
12211224
with pm.Model(coords=temp_coords if dims is not None else None) as forward_model:
12221225
self._build_dummy_graph()
12231226
self._insert_random_variables()
1227+
1228+
for name in self.data_names:
1229+
pm.Data(**self._exog_data_info[name])
1230+
1231+
self._insert_data_variables()
1232+
12241233
matrices = [x0, P0, c, d, T, Z, R, H, Q] = self.unpack_statespace()
12251234

12261235
if not self.measurement_error:
@@ -1419,6 +1428,66 @@ def sample_unconditional_posterior(
14191428
idata, "posterior", steps, use_data_time_dim, random_seed, **kwargs
14201429
)
14211430

1431+
def sample_statespace_matrices(
1432+
self, idata, matrix_names: str | list[str] | None, group: str = "posterior"
1433+
):
1434+
"""
1435+
Draw samples of requested statespace matrices from provided idata
1436+
1437+
Parameters
1438+
----------
1439+
matrix_names: str, list[str], optional
1440+
Statespace matrices to be sampled. Valid names are short names: x0, P0, c, d, T, Z, R, H, Q, or
1441+
"formal" names: initial_state, initial_state_cov, state_intercept, obs_intercept, transition, design,
1442+
selection, obs_cov, state_cov
1443+
idata: az.InferenceData
1444+
InferenceData from which to sample
1445+
1446+
group: str, one of "posterior" or "prior"
1447+
Whether to sample from priors or posteriors
1448+
1449+
Returns
1450+
-------
1451+
idata_matrices: az.InterenceData
1452+
"""
1453+
_verify_group(group)
1454+
1455+
if matrix_names is None:
1456+
matrix_names = MATRIX_NAMES
1457+
elif isinstance(matrix_names, str):
1458+
matrix_names = [matrix_names]
1459+
1460+
with pm.Model(coords=self._fit_coords) as forward_model:
1461+
self._build_dummy_graph()
1462+
self._insert_random_variables()
1463+
1464+
for name in self.data_names:
1465+
pm.Data(**self._exog_data_info[name])
1466+
1467+
self._insert_data_variables()
1468+
matrices = self.unpack_statespace()
1469+
for short_name, matrix in zip(MATRIX_NAMES, matrices):
1470+
long_name = SHORT_NAME_TO_LONG[short_name]
1471+
if (long_name in matrix_names) or (short_name in matrix_names):
1472+
name = long_name if long_name in matrix_names else short_name
1473+
dims = [x if x in self._fit_coords else None for x in MATRIX_DIMS[short_name]]
1474+
pm.Deterministic(name, matrix, dims=dims)
1475+
1476+
# TODO: Remove this after pm.Flat has its initial_value fixed
1477+
forward_model.rvs_to_initial_values = {
1478+
rv: None for rv in forward_model.rvs_to_initial_values.keys()
1479+
}
1480+
frozen_model = freeze_dims_and_data(forward_model)
1481+
with frozen_model:
1482+
matrix_idata = pm.sample_posterior_predictive(
1483+
idata if group == "posterior" else idata.prior,
1484+
var_names=matrix_names,
1485+
compile_kwargs={"mode": self._fit_mode},
1486+
extend_inferencedata=False,
1487+
)
1488+
1489+
return matrix_idata
1490+
14221491
def forecast(
14231492
self,
14241493
idata: InferenceData,

0 commit comments

Comments
 (0)