8
8
import pytensor .tensor as pt
9
9
from arviz import InferenceData
10
10
from pymc .model import modelcontext
11
+ from pymc .model .transform .optimization import freeze_dims_and_data
11
12
from pymc .util import RandomState
12
13
from pytensor import Variable , graph_replace
13
14
from pytensor .compile import get_mode
@@ -223,8 +224,9 @@ def __init__(
223
224
self ._fit_dims : Optional [dict [str , Sequence [str ]]] = None
224
225
self ._fit_data : Optional [pt .TensorVariable ] = None
225
226
226
- self ._needs_exog_data = False
227
+ self ._needs_exog_data = None
227
228
self ._exog_names = []
229
+ self ._exog_data_info = {}
228
230
self ._name_to_variable = {}
229
231
self ._name_to_data = {}
230
232
@@ -348,7 +350,7 @@ def data_names(self) -> list[str]:
348
350
This does not include the observed data series, which is automatically handled by PyMC. This property only
349
351
needs to be implemented for models that expect exogenous data.
350
352
"""
351
- raise NotImplementedError ( "The data_names property has not been implemented!" )
353
+ return []
352
354
353
355
@property
354
356
def param_info (self ) -> dict [str , dict [str , Any ]]:
@@ -620,6 +622,19 @@ def _get_matrix_shape_and_dims(
620
622
621
623
return shape , dims
622
624
625
+ def _save_exogenous_data_info (self ):
626
+ """
627
+ Store exogenous data required by posterior sampling functions
628
+ """
629
+ pymc_mod = modelcontext (None )
630
+ for data_name in self .data_names :
631
+ data = pymc_mod [data_name ]
632
+ self ._exog_data_info [data_name ] = {
633
+ "name" : data_name ,
634
+ "value" : data .get_value (),
635
+ "dims" : pymc_mod .named_vars_to_dims .get (data_name , None ),
636
+ }
637
+
623
638
def _insert_random_variables (self ):
624
639
"""
625
640
Replace pytensor symbolic variables with PyMC random variables.
@@ -743,63 +758,28 @@ def _register_matrices_with_pymc_model(self) -> list[pt.TensorVariable]:
743
758
@staticmethod
744
759
def _register_kalman_filter_outputs_with_pymc_model (outputs : tuple [pt .TensorVariable ]) -> None :
745
760
mod = modelcontext (None )
761
+ coords = mod .coords
762
+
746
763
states , covs = outputs [:4 ], outputs [4 :]
747
764
748
- state_names = ["filtered_state" , "predicted_state" , "observed_state" , "smoothed_state" ]
765
+ state_names = [
766
+ "filtered_state" ,
767
+ "predicted_state" ,
768
+ "predicted_observed_state" ,
769
+ "smoothed_state" ,
770
+ ]
749
771
cov_names = [
750
772
"filtered_covariance" ,
751
773
"predicted_covariance" ,
752
- "observed_covariance " ,
774
+ "predicted_observed_covariance " ,
753
775
"smoothed_covariance" ,
754
776
]
755
777
756
778
with mod :
757
- for state , name in zip (states , state_names ):
758
- pm .Deterministic (name , state , dims = FILTER_OUTPUT_DIMS .get (name , None ))
759
-
760
- for cov , name in zip (covs , cov_names ):
761
- pm .Deterministic (name , cov , dims = FILTER_OUTPUT_DIMS .get (name , None ))
762
-
763
- def add_exogenous (self , exog : pt .TensorVariable ) -> None :
764
- """
765
- Add an exogenous process to the statespace model
766
-
767
- Parameters
768
- ----------
769
- exog: TensorVariable
770
- An (N, k_endog) tensor representing exogenous processes to be included in the statespace model
771
-
772
- Notes
773
- -----
774
- This function can be used to "inject" absolutely any type of dynamics you wish into a statespace model.
775
- Recall that a statespace model is a system of two matrix equations:
776
-
777
- .. math::
778
- \b egin{align} X_t &= c_t + T_t x_{t-1} + R_t \v arepsilon_t & \v arepsilon_t &\\ sim N(0, Q_t) \\
779
- y_t &= d_t + Z_t x_t + \\ eta_t & \\ eta_t &\\ sim N(0, H_t)
780
- \\ end{align}
781
-
782
- Any of the matrices :math:`c, d, T, Z, R, H, Q` can vary across time. When this function is invoked, the
783
- provided exogenous data is set as the observation intercept, :math:`d_t`. This makes the statespace model
784
- a model of the residuals :math:`y_t - d_t`. In fact, this is precisely the quantity that is used to compute
785
- the likelihood of the data during Kalman filtering.
786
- """
787
- # User might pass a flat time-varying exog vector, need to make it a column
788
- if exog .ndim == 1 :
789
- exog = pt .expand_dims (exog , - 1 )
790
- elif (exog .ndim == 2 ) and (exog .type .shape [- 1 ] != 1 ):
791
- raise ValueError (
792
- f"If exogenous data is 2d, it must have a single column, found { exog .type .shape [- 1 ]} "
793
- )
794
- elif exog .ndim > 2 :
795
- raise ValueError (f"Exogenous data must be at most 2d, found { exog .ndim } dimensions" )
796
-
797
- # Need to specifically ask for the time dim (last one) when slicing into self.ssm
798
- d = self .ssm ["obs_intercept" , :, :]
799
- self .ssm ["obs_intercept" ] = d + exog
800
-
801
- self ._needs_exog_data = True
802
- self ._exog_names .append (exog .name )
779
+ for var , name in zip (states + covs , state_names + cov_names ):
780
+ dim_names = FILTER_OUTPUT_DIMS .get (name , None )
781
+ dims = tuple ([dim if dim in coords .keys () else None for dim in dim_names ])
782
+ pm .Deterministic (name , var , dims = dims )
803
783
804
784
def build_statespace_graph (
805
785
self ,
@@ -858,9 +838,11 @@ def build_statespace_graph(
858
838
pm_mod = modelcontext (None )
859
839
860
840
self ._insert_random_variables ()
841
+ self ._save_exogenous_data_info ()
861
842
self ._insert_data_variables ()
862
843
863
844
obs_coords = pm_mod .coords .get (OBS_STATE_DIM , None )
845
+ self ._fit_data = data
864
846
865
847
data , nan_mask = register_data_with_pymc (
866
848
data ,
@@ -890,7 +872,7 @@ def build_statespace_graph(
890
872
all_kf_outputs = states + [smooth_states ] + covs + [smooth_covariances ]
891
873
self ._register_kalman_filter_outputs_with_pymc_model (all_kf_outputs )
892
874
893
- obs_dims = FILTER_OUTPUT_DIMS ["obs " ]
875
+ obs_dims = FILTER_OUTPUT_DIMS ["predicted_observed_state " ]
894
876
obs_dims = obs_dims if all ([dim in pm_mod .coords .keys () for dim in obs_dims ]) else None
895
877
896
878
SequenceMvNormal (
@@ -905,7 +887,6 @@ def build_statespace_graph(
905
887
self ._fit_coords = pm_mod .coords .copy ()
906
888
self ._fit_dims = pm_mod .named_vars_to_dims .copy ()
907
889
self ._fit_mode = mode
908
- self ._fit_data = data
909
890
910
891
def _build_smoother_graph (
911
892
self ,
@@ -977,11 +958,21 @@ def _build_dummy_graph(self) -> None:
977
958
978
959
def _kalman_filter_outputs_from_dummy_graph (
979
960
self ,
961
+ data : pt .TensorLike | None = None ,
962
+ data_dims : str | tuple [str ] | list [str ] | None = None ,
980
963
) -> tuple [list [pt .TensorVariable ], list [tuple [pt .TensorVariable , pt .TensorVariable ]]]:
981
964
"""
982
965
Builds a Kalman filter graph using "dummy" pm.Flat distributions for the model variables and sorts the returns
983
966
into (mean, covariance) pairs for each of filtered, predicted, and smoothed output.
984
967
968
+ Parameters
969
+ ----------
970
+ data: pt.TensorLike, optional
971
+ Observed data on which to condition the model. If not provided, the function will use the data that was
972
+ provided when the model was built.
973
+ data_dims: str or tuple of str, optional
974
+ Dimension names associated with the model data. If None, defaults to ("time", "obs_state")
975
+
985
976
Returns
986
977
-------
987
978
matrices: list of tensors
@@ -990,11 +981,33 @@ def _kalman_filter_outputs_from_dummy_graph(
990
981
grouped_outputs: list of tuple of tensors
991
982
A list of tuples, each containing the mean and covariance of the filtered, predicted, and smoothed states.
992
983
"""
984
+ pm_mod = modelcontext (None )
993
985
self ._build_dummy_graph ()
986
+ self ._insert_random_variables ()
987
+
988
+ for name in self .data_names :
989
+ if name not in pm_mod :
990
+ pm .Data (** self ._exog_data_info [name ])
991
+
992
+ self ._insert_data_variables ()
993
+
994
994
x0 , P0 , c , d , T , Z , R , H , Q = self .unpack_statespace ()
995
995
996
+ if data is None :
997
+ data = self ._fit_data
998
+
999
+ obs_coords = pm_mod .coords .get (OBS_STATE_DIM , None )
1000
+
1001
+ data , nan_mask = register_data_with_pymc (
1002
+ data ,
1003
+ n_obs = self .ssm .k_endog ,
1004
+ obs_coords = obs_coords ,
1005
+ data_dims = data_dims ,
1006
+ register_data = True ,
1007
+ )
1008
+
996
1009
filter_outputs = self .kalman_filter .build_graph (
997
- pt . as_tensor_variable ( self . _fit_data ) ,
1010
+ data ,
998
1011
x0 ,
999
1012
P0 ,
1000
1013
c ,
@@ -1026,7 +1039,12 @@ def _kalman_filter_outputs_from_dummy_graph(
1026
1039
return [x0 , P0 , c , d , T , Z , R , H , Q ], grouped_outputs
1027
1040
1028
1041
def _sample_conditional (
1029
- self , idata : InferenceData , group : str , random_seed : Optional [RandomState ] = None , ** kwargs
1042
+ self ,
1043
+ idata : InferenceData ,
1044
+ group : str ,
1045
+ random_seed : RandomState | None = None ,
1046
+ data : pt .TensorLike | None = None ,
1047
+ ** kwargs ,
1030
1048
):
1031
1049
"""
1032
1050
Common functionality shared between `sample_conditional_prior` and `sample_conditional_posterior`. See those
@@ -1043,6 +1061,10 @@ def _sample_conditional(
1043
1061
random_seed : int, RandomState or Generator, optional
1044
1062
Seed for the random number generator.
1045
1063
1064
+ data: pt.TensorLike, optional
1065
+ Observed data on which to condition the model. If not provided, the function will use the data that was
1066
+ provided when the model was built.
1067
+
1046
1068
kwargs:
1047
1069
Additional keyword arguments are passed to pymc.sample_posterior_predictive
1048
1070
@@ -1052,11 +1074,13 @@ def _sample_conditional(
1052
1074
An Arviz InferenceData object containing sampled trajectories from the requested conditional distribution,
1053
1075
with data variables "filtered_{group}", "predicted_{group}", and "smoothed_{group}".
1054
1076
"""
1077
+ if data is None and self ._fit_data is None :
1078
+ raise ValueError ("No data provided to condition the model" )
1055
1079
1056
1080
_verify_group (group )
1057
1081
group_idata = getattr (idata , group )
1058
1082
1059
- with pm .Model (coords = self ._fit_coords ):
1083
+ with pm .Model (coords = self ._fit_coords ) as forward_model :
1060
1084
[
1061
1085
x0 ,
1062
1086
P0 ,
@@ -1067,7 +1091,7 @@ def _sample_conditional(
1067
1091
R ,
1068
1092
H ,
1069
1093
Q ,
1070
- ], grouped_outputs = self ._kalman_filter_outputs_from_dummy_graph ()
1094
+ ], grouped_outputs = self ._kalman_filter_outputs_from_dummy_graph (data = data )
1071
1095
1072
1096
for name , (mu , cov ) in zip (FILTER_OUTPUT_TYPES , grouped_outputs ):
1073
1097
dummy_ll = pt .zeros_like (mu )
@@ -1102,6 +1126,13 @@ def _sample_conditional(
1102
1126
dims = obs_dims ,
1103
1127
)
1104
1128
1129
+ # TODO: Remove this after pm.Flat initial values are fixed
1130
+ forward_model .rvs_to_initial_values = {
1131
+ rv : None for rv in forward_model .rvs_to_initial_values .keys ()
1132
+ }
1133
+
1134
+ frozen_model = freeze_dims_and_data (forward_model )
1135
+ with frozen_model :
1105
1136
idata_conditional = pm .sample_posterior_predictive (
1106
1137
group_idata ,
1107
1138
var_names = [
@@ -1187,8 +1218,9 @@ def _sample_unconditional(
1187
1218
if all ([dim in self ._fit_coords for dim in [TIME_DIM , ALL_STATE_DIM , OBS_STATE_DIM ]]):
1188
1219
dims = [TIME_DIM , ALL_STATE_DIM , OBS_STATE_DIM ]
1189
1220
1190
- with pm .Model (coords = temp_coords if dims is not None else None ):
1221
+ with pm .Model (coords = temp_coords if dims is not None else None ) as forward_model :
1191
1222
self ._build_dummy_graph ()
1223
+ self ._insert_random_variables ()
1192
1224
matrices = [x0 , P0 , c , d , T , Z , R , H , Q ] = self .unpack_statespace ()
1193
1225
1194
1226
if not self .measurement_error :
@@ -1204,8 +1236,16 @@ def _sample_unconditional(
1204
1236
dims = dims ,
1205
1237
mode = self ._fit_mode ,
1206
1238
sequence_names = self .kalman_filter .seq_names ,
1239
+ k_endog = self .k_endog ,
1207
1240
)
1208
1241
1242
+ # TODO: Remove this after pm.Flat has its initial_value fixed
1243
+ forward_model .rvs_to_initial_values = {
1244
+ rv : None for rv in forward_model .rvs_to_initial_values .keys ()
1245
+ }
1246
+ frozen_model = freeze_dims_and_data (forward_model )
1247
+
1248
+ with frozen_model :
1209
1249
idata_unconditional = pm .sample_posterior_predictive (
1210
1250
group_idata ,
1211
1251
var_names = [f"{ group } _latent" , f"{ group } _observed" ],
@@ -1494,7 +1534,7 @@ def forecast(
1494
1534
mu_dims = ["data_time" , ALL_STATE_DIM ]
1495
1535
cov_dims = ["data_time" , ALL_STATE_DIM , ALL_STATE_AUX_DIM ]
1496
1536
1497
- with pm .Model (coords = temp_coords ):
1537
+ with pm .Model (coords = temp_coords ) as forecast_model :
1498
1538
[
1499
1539
x0 ,
1500
1540
P0 ,
@@ -1505,7 +1545,9 @@ def forecast(
1505
1545
R ,
1506
1546
H ,
1507
1547
Q ,
1508
- ], grouped_outputs = self ._kalman_filter_outputs_from_dummy_graph ()
1548
+ ], grouped_outputs = self ._kalman_filter_outputs_from_dummy_graph (
1549
+ data_dims = ["data_time" , OBS_STATE_DIM ]
1550
+ )
1509
1551
group_idx = FILTER_OUTPUT_TYPES .index (filter_output )
1510
1552
1511
1553
mu , cov = grouped_outputs [group_idx ]
@@ -1532,8 +1574,15 @@ def forecast(
1532
1574
dims = dims ,
1533
1575
mode = self ._fit_mode ,
1534
1576
sequence_names = self .kalman_filter .seq_names ,
1577
+ k_endog = self .k_endog ,
1535
1578
)
1536
1579
1580
+ forecast_model .rvs_to_initial_values = {
1581
+ k : None for k in forecast_model .rvs_to_initial_values .keys ()
1582
+ }
1583
+ frozen_model = freeze_dims_and_data (forecast_model )
1584
+
1585
+ with frozen_model :
1537
1586
idata_forecast = pm .sample_posterior_predictive (
1538
1587
idata ,
1539
1588
var_names = ["forecast_latent" , "forecast_observed" ],
@@ -1542,7 +1591,7 @@ def forecast(
1542
1591
** kwargs ,
1543
1592
)
1544
1593
1545
- return idata_forecast .posterior_predictive
1594
+ return idata_forecast .posterior_predictive
1546
1595
1547
1596
def impulse_response_function (
1548
1597
self ,
@@ -1646,6 +1695,8 @@ def impulse_response_function(
1646
1695
1647
1696
with pm .Model (coords = simulation_coords ):
1648
1697
self ._build_dummy_graph ()
1698
+ self ._insert_random_variables ()
1699
+
1649
1700
P0 , _ , c , d , T , Z , R , H , post_Q = self .unpack_statespace ()
1650
1701
x0 = pm .Deterministic ("x0_new" , pt .zeros (self .k_states ), dims = [ALL_STATE_DIM ])
1651
1702
0 commit comments