Skip to content

Commit 960d682

Browse files
_insert_random_variables infers shape from coordinates
1 parent beaa56c commit 960d682

File tree

1 file changed

+8
-67
lines changed

1 file changed

+8
-67
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 8 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -284,14 +284,16 @@ def _print_data_requirements(self) -> None:
284284
Prints a short report to the terminal about the data needed for the model, including their names, shapes,
285285
and named dimensions.
286286
"""
287+
if not self.data_info:
288+
return
287289

288290
out = ""
289291
for data, info in self.data_info.items():
290292
out += f'\t{data} -- shape: {info["shape"]}, dims: {info["dims"]}\n'
291293
out = out.rstrip()
292294

293295
_log.info(
294-
"The following MutableData variables should be assigned to the model inside a PyMC "
296+
"The following Data variables should be assigned to the model inside a PyMC "
295297
f"model block: \n"
296298
f"{out}"
297299
)
@@ -618,63 +620,6 @@ def _get_matrix_shape_and_dims(
618620

619621
return shape, dims
620622

621-
def _get_output_shape_and_dims(
622-
self, idata: InferenceData, filter_output: str
623-
) -> tuple[
624-
Optional[tuple[int]], Optional[tuple[int]], Optional[tuple[str]], Optional[tuple[str]]
625-
]:
626-
"""
627-
Get the shapes and dimensions of the output variables from the provided InferenceData.
628-
629-
This method extracts the shapes and dimensions of the output variables representing the state estimates
630-
and covariances from the provided ArviZ InferenceData object. The state estimates are obtained from the
631-
specified `filter_output` mode, which can be one of "filtered", "predicted", or "smoothed".
632-
633-
Parameters
634-
----------
635-
idata : arviz.InferenceData
636-
The ArviZ InferenceData object containing the posterior samples.
637-
638-
filter_output : str
639-
The name of the filter output whose shape is being checked. It can be one of "filtered",
640-
"predicted", or "smoothed".
641-
642-
Returns
643-
-------
644-
mu_shape: tuple(int, int) or None
645-
Shape of the mean vectors returned by the Kalman filter. Should be (n_data_timesteps, k_states).
646-
If named dimensions are found, this will be None.
647-
648-
cov_shape: tuple (int, int, int) or None
649-
Shape of the hidden state covariance matrices returned by the Kalman filter. Should be
650-
(n_data_timesteps, k_states, k_states).
651-
If named dimensions are found, this will be None.
652-
653-
mu_dims: tuple(str, str) or None
654-
*Default* named dimensions associated with the mean vectors returned by the Kalman filter, or None if
655-
the default names are not found.
656-
657-
cov_dims: tuple (str, str, str) or None
658-
*Default* named dimensions associated with the covariance matrices returned by the Kalman filter, or None if
659-
the default names are not found.
660-
"""
661-
662-
mu_dims = None
663-
cov_dims = None
664-
665-
mu_shape = idata[f"{filter_output}_state"].values.shape[2:]
666-
cov_shape = idata[f"{filter_output}_covariance"].values.shape[2:]
667-
668-
if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]):
669-
time_dim = TIME_DIM
670-
mu_dims = [time_dim, ALL_STATE_DIM]
671-
cov_dims = [time_dim, ALL_STATE_DIM, ALL_STATE_AUX_DIM]
672-
673-
mu_shape = None
674-
cov_shape = None
675-
676-
return mu_shape, cov_shape, mu_dims, cov_dims
677-
678623
def _insert_random_variables(self):
679624
"""
680625
Replace pytensor symbolic variables with PyMC random variables.
@@ -1506,11 +1451,11 @@ def forecast(
15061451
"Scenario-based forcasting with exogenous variables not currently supported"
15071452
)
15081453

1509-
dims = None
15101454
temp_coords = self._fit_coords.copy()
15111455

15121456
filter_time_dim = TIME_DIM
15131457

1458+
dims = None
15141459
if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]):
15151460
dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
15161461

@@ -1544,14 +1489,10 @@ def forecast(
15441489
temp_coords["data_time"] = time_index
15451490
temp_coords[TIME_DIM] = forecast_index
15461491

1547-
mu_shape, cov_shape, mu_dims, cov_dims = self._get_output_shape_and_dims(
1548-
idata.posterior, filter_output
1549-
)
1550-
1551-
if mu_dims is not None:
1552-
mu_dims = ["data_time"] + mu_dims[1:]
1553-
if cov_dims is not None:
1554-
cov_dims = ["data_time"] + cov_dims[1:]
1492+
mu_dims, cov_dims = None, None
1493+
if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]):
1494+
mu_dims = ["data_time", ALL_STATE_DIM]
1495+
cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM]
15551496

15561497
with pm.Model(coords=temp_coords):
15571498
[

0 commit comments

Comments
 (0)