Skip to content

Commit 09a2dc8

Browse files
Draft example notebook
1 parent d18f594 commit 09a2dc8

File tree

6 files changed

+6
-8
lines changed

6 files changed

+6
-8
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
from pymc_experimental.statespace.core.compile import compile_statespace
12
from pymc_experimental.statespace.models import structural
23
from pymc_experimental.statespace.models.ETS import BayesianETS
34
from pymc_experimental.statespace.models.SARIMAX import BayesianSARIMA
45
from pymc_experimental.statespace.models.VARMAX import BayesianVARMAX
5-
from pymc_experimental.statespace.utils import compile_statespace
66

77
__all__ = ["structural", "BayesianSARIMA", "BayesianVARMAX", "BayesianETS", "compile_statespace"]
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1+
from pymc_experimental.statespace.core.compile import compile_statespace
12
from pymc_experimental.statespace.core.representation import PytensorRepresentation
23
from pymc_experimental.statespace.core.statespace import PyMCStateSpace
34

4-
__all__ = ["PytensorRepresentation", "PyMCStateSpace"]
5+
__all__ = ["PytensorRepresentation", "PyMCStateSpace", "compile_statespace"]

pymc_experimental/statespace/core/statespace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ def add_default_priors(self) -> None:
461461

462462
def make_and_register_variable(
463463
self, name, shape: int | tuple[int, ...] | None = None, dtype=floatX
464-
) -> Variable:
464+
) -> pt.TensorVariable:
465465
"""
466466
Helper function to create a pytensor symbolic variable and register it in the _name_to_variable dictionary
467467

pymc_experimental/statespace/models/SARIMAX.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,8 @@ def param_dims(self):
334334
"seasonal_ma_params": (SEASONAL_MA_PARAM_DIM,),
335335
}
336336
if self.k_endog == 1:
337-
coord_map["sigma_state"] = ()
338-
coord_map["sigma_obs"] = ()
337+
coord_map["sigma_state"] = None
338+
coord_map["sigma_obs"] = None
339339
if not self.measurement_error:
340340
del coord_map["sigma_obs"]
341341
if self.p == 0:
Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
from pymc_experimental.statespace.utils.compile import compile_statespace
2-
3-
__all__ = ["compile_statespace"]

0 commit comments

Comments
 (0)