Skip to content

Commit 65b113e

Browse files
Allow forward sampling of statespace models in JAX mode
Explicitly set data shape to avoid broadcasting error. Better handling of measurement error dims in `SARIMAX` models. Freeze auxiliary models before forward sampling. Bugfixes for posterior predictive sampling helpers. Allow specification of time dimension name when registering data. Save info about exogenous data for post-estimation tasks. Restore `_exog_data_info` member variable. Be more consistent with the names of filter outputs.
1 parent dfe3fe0 commit 65b113e

File tree

5 files changed

+137
-77
lines changed

5 files changed

+137
-77
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 111 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytensor.tensor as pt
99
from arviz import InferenceData
1010
from pymc.model import modelcontext
11+
from pymc.model.transform.optimization import freeze_dims_and_data
1112
from pymc.util import RandomState
1213
from pytensor import Variable, graph_replace
1314
from pytensor.compile import get_mode
@@ -223,8 +224,9 @@ def __init__(
223224
self._fit_dims: Optional[dict[str, Sequence[str]]] = None
224225
self._fit_data: Optional[pt.TensorVariable] = None
225226

226-
self._needs_exog_data = False
227+
self._needs_exog_data = None
227228
self._exog_names = []
229+
self._exog_data_info = {}
228230
self._name_to_variable = {}
229231
self._name_to_data = {}
230232

@@ -348,7 +350,7 @@ def data_names(self) -> list[str]:
348350
This does not include the observed data series, which is automatically handled by PyMC. This property only
349351
needs to be implemented for models that expect exogenous data.
350352
"""
351-
raise NotImplementedError("The data_names property has not been implemented!")
353+
return []
352354

353355
@property
354356
def param_info(self) -> dict[str, dict[str, Any]]:
@@ -620,6 +622,19 @@ def _get_matrix_shape_and_dims(
620622

621623
return shape, dims
622624

625+
def _save_exogenous_data_info(self):
626+
"""
627+
Store exogenous data required by posterior sampling functions
628+
"""
629+
pymc_mod = modelcontext(None)
630+
for data_name in self.data_names:
631+
data = pymc_mod[data_name]
632+
self._exog_data_info[data_name] = {
633+
"name": data_name,
634+
"value": data.get_value(),
635+
"dims": pymc_mod.named_vars_to_dims.get(data_name, None),
636+
}
637+
623638
def _insert_random_variables(self):
624639
"""
625640
Replace pytensor symbolic variables with PyMC random variables.
@@ -743,63 +758,28 @@ def _register_matrices_with_pymc_model(self) -> list[pt.TensorVariable]:
743758
@staticmethod
744759
def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVariable]) -> None:
745760
mod = modelcontext(None)
761+
coords = mod.coords
762+
746763
states, covs = outputs[:4], outputs[4:]
747764

748-
state_names = ["filtered_state", "predicted_state", "observed_state", "smoothed_state"]
765+
state_names = [
766+
"filtered_state",
767+
"predicted_state",
768+
"predicted_observed_state",
769+
"smoothed_state",
770+
]
749771
cov_names = [
750772
"filtered_covariance",
751773
"predicted_covariance",
752-
"observed_covariance",
774+
"predicted_observed_covariance",
753775
"smoothed_covariance",
754776
]
755777

756778
with mod:
757-
for state, name in zip(states, state_names):
758-
pm.Deterministic(name, state, dims=FILTER_OUTPUT_DIMS.get(name, None))
759-
760-
for cov, name in zip(covs, cov_names):
761-
pm.Deterministic(name, cov, dims=FILTER_OUTPUT_DIMS.get(name, None))
762-
763-
def add_exogenous(self, exog: pt.TensorVariable) -> None:
764-
"""
765-
Add an exogenous process to the statespace model
766-
767-
Parameters
768-
----------
769-
exog: TensorVariable
770-
An (N, k_endog) tensor representing exogenous processes to be included in the statespace model
771-
772-
Notes
773-
-----
774-
This function can be used to "inject" absolutely any type of dynamics you wish into a statespace model.
775-
Recall that a statespace model is a system of two matrix equations:
776-
777-
.. math::
778-
\begin{align} X_t &= c_t + T_t x_{t-1} + R_t \varepsilon_t & \varepsilon_t &\\sim N(0, Q_t) \\
779-
y_t &= d_t + Z_t x_t + \\eta_t & \\eta_t &\\sim N(0, H_t)
780-
\\end{align}
781-
782-
Any of the matrices :math:`c, d, T, Z, R, H, Q` can vary across time. When this function is invoked, the
783-
provided exogenous data is set as the observation intercept, :math:`d_t`. This makes the statespace model
784-
a model of the residuals :math:`y_t - d_t`. In fact, this is precisely the quantity that is used to compute
785-
the likelihood of the data during Kalman filtering.
786-
"""
787-
# User might pass a flat time-varying exog vector, need to make it a column
788-
if exog.ndim == 1:
789-
exog = pt.expand_dims(exog, -1)
790-
elif (exog.ndim == 2) and (exog.type.shape[-1] != 1):
791-
raise ValueError(
792-
f"If exogenous data is 2d, it must have a single column, found {exog.type.shape[-1]}"
793-
)
794-
elif exog.ndim > 2:
795-
raise ValueError(f"Exogenous data must be at most 2d, found {exog.ndim} dimensions")
796-
797-
# Need to specifically ask for the time dim (last one) when slicing into self.ssm
798-
d = self.ssm["obs_intercept", :, :]
799-
self.ssm["obs_intercept"] = d + exog
800-
801-
self._needs_exog_data = True
802-
self._exog_names.append(exog.name)
779+
for var, name in zip(states + covs, state_names + cov_names):
780+
dim_names = FILTER_OUTPUT_DIMS.get(name, None)
781+
dims = tuple([dim if dim in coords.keys() else None for dim in dim_names])
782+
pm.Deterministic(name, var, dims=dims)
803783

804784
def build_statespace_graph(
805785
self,
@@ -858,9 +838,11 @@ def build_statespace_graph(
858838
pm_mod = modelcontext(None)
859839

860840
self._insert_random_variables()
841+
self._save_exogenous_data_info()
861842
self._insert_data_variables()
862843

863844
obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None)
845+
self._fit_data = data
864846

865847
data, nan_mask = register_data_with_pymc(
866848
data,
@@ -890,7 +872,7 @@ def build_statespace_graph(
890872
all_kf_outputs = states + [smooth_states] + covs + [smooth_covariances]
891873
self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs)
892874

893-
obs_dims = FILTER_OUTPUT_DIMS["obs"]
875+
obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_state"]
894876
obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None
895877

896878
SequenceMvNormal(
@@ -905,7 +887,6 @@ def build_statespace_graph(
905887
self._fit_coords = pm_mod.coords.copy()
906888
self._fit_dims = pm_mod.named_vars_to_dims.copy()
907889
self._fit_mode = mode
908-
self._fit_data = data
909890

910891
def _build_smoother_graph(
911892
self,
@@ -977,11 +958,21 @@ def _build_dummy_graph(self) -> None:
977958

978959
def _kalman_filter_outputs_from_dummy_graph(
979960
self,
961+
data: pt.TensorLike | None = None,
962+
data_dims: str | tuple[str] | list[str] | None = None,
980963
) -> tuple[list[pt.TensorVariable], list[tuple[pt.TensorVariable, pt.TensorVariable]]]:
981964
"""
982965
Builds a Kalman filter graph using "dummy" pm.Flat distributions for the model variables and sorts the returns
983966
into (mean, covariance) pairs for each of filtered, predicted, and smoothed output.
984967
968+
Parameters
969+
----------
970+
data: pt.TensorLike, optional
971+
Observed data on which to condition the model. If not provided, the function will use the data that was
972+
provided when the model was built.
973+
data_dims: str or tuple of str, optional
974+
Dimension names associated with the model data. If None, defaults to ("time", "obs_state")
975+
985976
Returns
986977
-------
987978
matrices: list of tensors
@@ -990,11 +981,33 @@ def _kalman_filter_outputs_from_dummy_graph(
990981
grouped_outputs: list of tuple of tensors
991982
A list of tuples, each containing the mean and covariance of the filtered, predicted, and smoothed states.
992983
"""
984+
pm_mod = modelcontext(None)
993985
self._build_dummy_graph()
986+
self._insert_random_variables()
987+
988+
for name in self.data_names:
989+
if name not in pm_mod:
990+
pm.Data(**self._exog_data_info[name])
991+
992+
self._insert_data_variables()
993+
994994
x0, P0, c, d, T, Z, R, H, Q = self.unpack_statespace()
995995

996+
if data is None:
997+
data = self._fit_data
998+
999+
obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None)
1000+
1001+
data, nan_mask = register_data_with_pymc(
1002+
data,
1003+
n_obs=self.ssm.k_endog,
1004+
obs_coords=obs_coords,
1005+
data_dims=data_dims,
1006+
register_data=True,
1007+
)
1008+
9961009
filter_outputs = self.kalman_filter.build_graph(
997-
pt.as_tensor_variable(self._fit_data),
1010+
data,
9981011
x0,
9991012
P0,
10001013
c,
@@ -1026,7 +1039,12 @@ def _kalman_filter_outputs_from_dummy_graph(
10261039
return [x0, P0, c, d, T, Z, R, H, Q], grouped_outputs
10271040

10281041
def _sample_conditional(
1029-
self, idata: InferenceData, group: str, random_seed: Optional[RandomState] = None, **kwargs
1042+
self,
1043+
idata: InferenceData,
1044+
group: str,
1045+
random_seed: RandomState | None = None,
1046+
data: pt.TensorLike | None = None,
1047+
**kwargs,
10301048
):
10311049
"""
10321050
Common functionality shared between `sample_conditional_prior` and `sample_conditional_posterior`. See those
@@ -1043,6 +1061,10 @@ def _sample_conditional(
10431061
random_seed : int, RandomState or Generator, optional
10441062
Seed for the random number generator.
10451063
1064+
data: pt.TensorLike, optional
1065+
Observed data on which to condition the model. If not provided, the function will use the data that was
1066+
provided when the model was built.
1067+
10461068
kwargs:
10471069
Additional keyword arguments are passed to pymc.sample_posterior_predictive
10481070
@@ -1052,11 +1074,13 @@ def _sample_conditional(
10521074
An Arviz InferenceData object containing sampled trajectories from the requested conditional distribution,
10531075
with data variables "filtered_{group}", "predicted_{group}", and "smoothed_{group}".
10541076
"""
1077+
if data is None and self._fit_data is None:
1078+
raise ValueError("No data provided to condition the model")
10551079

10561080
_verify_group(group)
10571081
group_idata = getattr(idata, group)
10581082

1059-
with pm.Model(coords=self._fit_coords):
1083+
with pm.Model(coords=self._fit_coords) as forward_model:
10601084
[
10611085
x0,
10621086
P0,
@@ -1067,7 +1091,7 @@ def _sample_conditional(
10671091
R,
10681092
H,
10691093
Q,
1070-
], grouped_outputs = self._kalman_filter_outputs_from_dummy_graph()
1094+
], grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(data=data)
10711095

10721096
for name, (mu, cov) in zip(FILTER_OUTPUT_TYPES, grouped_outputs):
10731097
dummy_ll = pt.zeros_like(mu)
@@ -1102,6 +1126,13 @@ def _sample_conditional(
11021126
dims=obs_dims,
11031127
)
11041128

1129+
# TODO: Remove this after pm.Flat initial values are fixed
1130+
forward_model.rvs_to_initial_values = {
1131+
rv: None for rv in forward_model.rvs_to_initial_values.keys()
1132+
}
1133+
1134+
frozen_model = freeze_dims_and_data(forward_model)
1135+
with frozen_model:
11051136
idata_conditional = pm.sample_posterior_predictive(
11061137
group_idata,
11071138
var_names=[
@@ -1187,8 +1218,9 @@ def _sample_unconditional(
11871218
if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]]):
11881219
dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
11891220

1190-
with pm.Model(coords=temp_coords if dims is not None else None):
1221+
with pm.Model(coords=temp_coords if dims is not None else None) as forward_model:
11911222
self._build_dummy_graph()
1223+
self._insert_random_variables()
11921224
matrices = [x0, P0, c, d, T, Z, R, H, Q] = self.unpack_statespace()
11931225

11941226
if not self.measurement_error:
@@ -1204,8 +1236,16 @@ def _sample_unconditional(
12041236
dims=dims,
12051237
mode=self._fit_mode,
12061238
sequence_names=self.kalman_filter.seq_names,
1239+
k_endog=self.k_endog,
12071240
)
12081241

1242+
# TODO: Remove this after pm.Flat has its initial_value fixed
1243+
forward_model.rvs_to_initial_values = {
1244+
rv: None for rv in forward_model.rvs_to_initial_values.keys()
1245+
}
1246+
frozen_model = freeze_dims_and_data(forward_model)
1247+
1248+
with frozen_model:
12091249
idata_unconditional = pm.sample_posterior_predictive(
12101250
group_idata,
12111251
var_names=[f"{group}_latent", f"{group}_observed"],
@@ -1494,7 +1534,7 @@ def forecast(
14941534
mu_dims = ["data_time", ALL_STATE_DIM]
14951535
cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM]
14961536

1497-
with pm.Model(coords=temp_coords):
1537+
with pm.Model(coords=temp_coords) as forecast_model:
14981538
[
14991539
x0,
15001540
P0,
@@ -1505,7 +1545,9 @@ def forecast(
15051545
R,
15061546
H,
15071547
Q,
1508-
], grouped_outputs = self._kalman_filter_outputs_from_dummy_graph()
1548+
], grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
1549+
data_dims=["data_time", OBS_STATE_DIM]
1550+
)
15091551
group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
15101552

15111553
mu, cov = grouped_outputs[group_idx]
@@ -1532,8 +1574,15 @@ def forecast(
15321574
dims=dims,
15331575
mode=self._fit_mode,
15341576
sequence_names=self.kalman_filter.seq_names,
1577+
k_endog=self.k_endog,
15351578
)
15361579

1580+
forecast_model.rvs_to_initial_values = {
1581+
k: None for k in forecast_model.rvs_to_initial_values.keys()
1582+
}
1583+
frozen_model = freeze_dims_and_data(forecast_model)
1584+
1585+
with frozen_model:
15371586
idata_forecast = pm.sample_posterior_predictive(
15381587
idata,
15391588
var_names=["forecast_latent", "forecast_observed"],
@@ -1542,7 +1591,7 @@ def forecast(
15421591
**kwargs,
15431592
)
15441593

1545-
return idata_forecast.posterior_predictive
1594+
return idata_forecast.posterior_predictive
15461595

15471596
def impulse_response_function(
15481597
self,
@@ -1646,6 +1695,8 @@ def impulse_response_function(
16461695

16471696
with pm.Model(coords=simulation_coords):
16481697
self._build_dummy_graph()
1698+
self._insert_random_variables()
1699+
16491700
P0, _, c, d, T, Z, R, H, post_Q = self.unpack_statespace()
16501701
x0 = pm.Deterministic("x0_new", pt.zeros(self.k_states), dims=[ALL_STATE_DIM])
16511702

pymc_experimental/statespace/models/SARIMAX.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,9 @@ def param_dims(self):
333333
"seasonal_ma_params": (SEASONAL_MA_PARAM_DIM,),
334334
}
335335
if self.k_endog == 1:
336-
del coord_map["sigma_state"]
337-
if not self.measurement_error or self.k_endog == 1:
336+
coord_map["sigma_state"] = ()
337+
coord_map["sigma_obs"] = ()
338+
if not self.measurement_error:
338339
del coord_map["sigma_obs"]
339340
if self.p == 0:
340341
del coord_map["ar_params"]

pymc_experimental/statespace/utils/constants.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
]
4545

4646
SMOOTHER_OUTPUT_NAMES = ["smoothed_state", "smoothed_covariance"]
47-
OBSERVED_OUTPUT_NAMES = ["observed_state", "observed_covariance"]
47+
OBSERVED_OUTPUT_NAMES = ["predicted_observed_state", "predicted_observed_covariance"]
4848

4949
MATRIX_DIMS = {
5050
"x0": (ALL_STATE_DIM,),
@@ -65,7 +65,8 @@
6565
"filtered_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
6666
"smoothed_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
6767
"predicted_covariance": (TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM),
68-
"obs": (TIME_DIM, OBS_STATE_DIM),
68+
"predicted_observed_state": (TIME_DIM, OBS_STATE_DIM),
69+
"predicted_observed_covariance": (TIME_DIM, OBS_STATE_DIM, OBS_STATE_AUX_DIM),
6970
}
7071

7172
POSITION_DERIVATIVE_NAMES = ["level", "trend", "acceleration", "jerk", "snap", "crackle", "pop"]

0 commit comments

Comments
 (0)