Skip to content

Commit 50e71ba

Browse files
Add innovation hidden state, data is observed state
1 parent 292dbdd commit 50e71ba

File tree

2 files changed

+27
-23
lines changed

2 files changed

+27
-23
lines changed

pymc_experimental/statespace/models/ETS.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:
272272

273273
@property
274274
def state_names(self):
275-
states = ["data", "level"]
275+
states = ["innovation", "level"]
276276
if self.trend:
277277
states += ["trend"]
278278
if self.seasonal:
@@ -282,7 +282,7 @@ def state_names(self):
282282

283283
@property
284284
def observed_states(self):
285-
return [self.state_names[0]]
285+
return ["data"]
286286

287287
@property
288288
def shock_names(self):
@@ -326,20 +326,15 @@ def make_symbolic_graph(self) -> None:
326326
self.ssm["initial_state", :] = x0
327327
self.ssm["initial_state_cov"] = P0
328328

329-
# Only the first state is ever observed
330-
Z = np.zeros((self.k_endog, self.k_states))
331-
Z[0, 0] = 1
332-
self.ssm["design"] = Z
333-
334329
# The shape of R can be pre-allocated, then filled with the required parameters
335330
R = pt.zeros((self.k_states, self.k_posdef))
336331
R = pt.set_subtensor(R[0, :], 1.0) # We will always have y_t = ... + e_t
337332

338333
alpha = self.make_and_register_variable("alpha", shape=(), dtype=floatX)
339334
R = pt.set_subtensor(R[1, 0], alpha) # and l_t = ... + alpha * e_t
340335

341-
# Data and level component always exists, the base case is y_t = l_{t-1} and l_t = l_{t-1}
342-
T_base = pt.as_tensor_variable(np.array([[0.0, 1.0], [0.0, 1.0]]))
336+
# Shock and level component always exists, the base case is e_t = e_t and l_t = l_{t-1}
337+
T_base = pt.as_tensor_variable(np.array([[0.0, 0.0], [0.0, 1.0]]))
343338

344339
if self.trend:
345340
beta = self.make_and_register_variable("beta", shape=(), dtype=floatX)
@@ -349,7 +344,7 @@ def make_symbolic_graph(self) -> None:
349344
# y_t = l_{t-1} + b_{t-1}
350345
# l_t = l_{t-1} + b_{t-1}
351346
# b_t = b_{t-1}
352-
T_base = pt.as_tensor_variable(([0.0, 1.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]))
347+
T_base = pt.as_tensor_variable(([0.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]))
353348

354349
if self.damped_trend:
355350
phi = self.make_and_register_variable("phi", shape=(), dtype=floatX)
@@ -358,7 +353,7 @@ def make_symbolic_graph(self) -> None:
358353
# y_t = l_{t-1} + phi * b_{t-1}
359354
# l_t = l_{t-1} + phi * b_{t-1}
360355
# b_t = phi * b_{t-1}
361-
T_base = pt.set_subtensor(T_base[:, 2], phi)
356+
T_base = pt.set_subtensor(T_base[1:, 2], phi)
362357

363358
T_components = [T_base]
364359

@@ -375,10 +370,17 @@ def make_symbolic_graph(self) -> None:
375370
self.ssm["selection"] = R
376371

377372
T = pt.linalg.block_diag(*T_components)
378-
if self.seasonal:
379-
T = pt.set_subtensor(T[0, 2 + int(self.trend) + 1], 1.0)
380373
self.ssm["transition"] = pt.specify_shape(T, (self.k_states, self.k_states))
381374

375+
Z = np.zeros((self.k_endog, self.k_states))
376+
Z[0, 0] = 1.0 # innovation
377+
Z[0, 1] = 1.0 # level
378+
if self.trend:
379+
Z[0, 2] = 1.0
380+
if self.seasonal:
381+
Z[0, 2 + int(self.trend)] = 1.0
382+
self.ssm["design"] = Z
383+
382384
# Set up the state covariance matrix
383385
state_cov_idx = ("state_cov",) + np.diag_indices(self.k_posdef)
384386
state_cov = self.make_and_register_variable(

tests/statespace/test_ETS.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,38 +153,40 @@ def test_statespace_matrices(order: tuple[str, str, str], expected_params: list[
153153
assert_allclose(H, np.eye(1) * test_values["sigma_obs"] ** 2)
154154
assert_allclose(Q, np.eye(1) * test_values["sigma_state"] ** 2)
155155

156-
Z_val = np.zeros((1, expected_states))
157-
Z_val[0, 0] = 1.0
158-
assert_allclose(Z, Z_val)
159-
160156
R_val = np.zeros((expected_states, 1))
161157
R_val[0] = 1.0
162158
R_val[1] = test_values["alpha"]
163159

160+
Z_val = np.zeros((1, expected_states))
161+
Z_val[0, 0] = 1.0
162+
Z_val[0, 1] = 1.0
163+
164164
if order[1] == "N":
165-
T_val = np.array([[0.0, 1.0], [0.0, 1.0]])
165+
T_val = np.array([[0.0, 0.0], [0.0, 1.0]])
166166
else:
167167
R_val[2] = test_values["beta"]
168-
T_val = np.array([[0.0, 1.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]])
168+
T_val = np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]])
169+
Z_val[0, 2] = 1.0
169170

170171
if order[1] == "Ad":
171-
T_val[:, -1] *= test_values["phi"]
172+
T_val[1:, -1] *= test_values["phi"]
172173

173174
if order[2] == "A":
174175
R_val[3] = test_values["gamma"]
175176
S = np.eye(seasonal_periods, k=-1)
176177
S[0, :] = -1
178+
Z_val[0, 2 + int(order[1] != "N")] = 1.0
177179
else:
178180
S = np.eye(0)
181+
179182
T_val = linalg.block_diag(T_val, S)
180-
if order[2] != "N":
181-
T_val[0, 3 + int(order[1] != "N")] = 1.0
182183

183184
assert_allclose(T, T_val)
184185
assert_allclose(R, R_val)
186+
assert_allclose(Z, Z_val)
185187

186188

187-
def test_simulate_model():
189+
def test_deterministic_simulation_matches_statsmodels():
188190
mod = BayesianETS(order=("A", "Ad", "A"), seasonal_periods=4, measurement_error=False)
189191

190192
rng = np.random.default_rng()

0 commit comments

Comments
 (0)