Skip to content

Commit 3a48e39

Browse files
Allow mutlivariate ETS models
1 parent 33700a4 commit 3a48e39

File tree

2 files changed

+288
-52
lines changed

2 files changed

+288
-52
lines changed

pymc_experimental/statespace/models/ETS.py

Lines changed: 176 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ALL_STATE_AUX_DIM,
1111
ALL_STATE_DIM,
1212
ETS_SEASONAL_DIM,
13+
OBS_STATE_AUX_DIM,
1314
OBS_STATE_DIM,
1415
)
1516

@@ -176,12 +177,15 @@ class BayesianETS(PyMCStateSpace):
176177
def __init__(
177178
self,
178179
order: tuple[str, str, str] | None = None,
180+
endog_names: str | list[str] | None = None,
181+
k_endog: int = 1,
179182
trend: bool = True,
180183
damped_trend: bool = False,
181184
seasonal: bool = False,
182185
seasonal_periods: int | None = None,
183186
measurement_error: bool = False,
184187
use_transformed_parameterization: bool = False,
188+
dense_innovation_covariance: bool = False,
185189
filter_type: str = "standard",
186190
verbose: bool = True,
187191
):
@@ -214,13 +218,26 @@ def __init__(
214218
if self.seasonal and self.seasonal_periods is None:
215219
raise ValueError("If seasonal is True, seasonal_periods must be provided.")
216220

221+
if endog_names is not None:
222+
endog_names = list(endog_names)
223+
k_endog = len(endog_names)
224+
else:
225+
endog_names = [f"data_{i}" for i in range(k_endog)] if k_endog > 1 else ["data"]
226+
227+
self.endog_names = endog_names
228+
229+
if dense_innovation_covariance and k_endog == 1:
230+
dense_innovation_covariance = False
231+
232+
self.dense_innovation_covariance = dense_innovation_covariance
233+
217234
k_states = (
218235
2
219236
+ int(trend)
220237
+ int(seasonal) * (seasonal_periods if seasonal_periods is not None else 0)
221-
)
222-
k_posdef = 1
223-
k_endog = 1
238+
) * k_endog
239+
240+
k_posdef = k_endog
224241

225242
super().__init__(
226243
k_endog,
@@ -243,6 +260,7 @@ def param_names(self):
243260
"gamma",
244261
"phi",
245262
"sigma_state",
263+
"state_cov",
246264
"sigma_obs",
247265
]
248266
if not self.trend:
@@ -256,6 +274,11 @@ def param_names(self):
256274
if not self.measurement_error:
257275
names.remove("sigma_obs")
258276

277+
if self.dense_innovation_covariance:
278+
names.remove("sigma_state")
279+
else:
280+
names.remove("state_cov")
281+
259282
return names
260283

261284
@property
@@ -283,27 +306,34 @@ def param_info(self) -> dict[str, dict[str, Any]]:
283306
"constraints": "Positive",
284307
},
285308
"alpha": {
286-
"shape": None,
309+
"shape": None if self.k_endog == 1 else (self.k_endog,),
287310
"constraints": "0 < alpha < 1",
288311
},
289312
"beta": {
290-
"shape": None,
313+
"shape": None if self.k_endog == 1 else (self.k_endog,),
291314
"constraints": "0 < beta < 1"
292315
if not self.use_transformed_parameterization
293316
else "0 < beta < alpha",
294317
},
295318
"gamma": {
296-
"shape": None,
319+
"shape": None if self.k_endog == 1 else (self.k_endog,),
297320
"constraints": "0 < gamma< 1"
298321
if not self.use_transformed_parameterization
299322
else "0 < gamma < (1 - alpha)",
300323
},
301324
"phi": {
302-
"shape": None,
325+
"shape": None if self.k_endog == 1 else (self.k_endog,),
303326
"constraints": "0 < phi < 1",
304327
},
305328
}
306329

330+
if self.dense_innovation_covariance:
331+
del info["sigma_state"]
332+
info["state_cov"] = {
333+
"shape": (self.k_posdef, self.k_posdef),
334+
"constraints": "Positive Semi-definite",
335+
}
336+
307337
for name in self.param_names:
308338
info[name]["dims"] = self.param_dims.get(name, None)
309339

@@ -317,15 +347,22 @@ def state_names(self):
317347
if self.seasonal:
318348
states += [f"L{i}.season" for i in range(1, self.seasonal_periods + 1)]
319349

350+
if self.k_endog > 1:
351+
states = [f"{name}_{state}" for name in self.endog_names for state in states]
352+
320353
return states
321354

322355
@property
323356
def observed_states(self):
324-
return ["data"]
357+
return self.endog_names
325358

326359
@property
327360
def shock_names(self):
328-
return ["innovation"]
361+
return (
362+
["innovation"]
363+
if self.k_endog == 1
364+
else [f"{name}_innovation" for name in self.endog_names]
365+
)
329366

330367
@property
331368
def param_dims(self):
@@ -339,11 +376,23 @@ def param_dims(self):
339376
"seasonal_param": (ETS_SEASONAL_DIM,),
340377
}
341378

379+
if self.dense_innovation_covariance:
380+
del coord_map["sigma_state"]
381+
coord_map["state_cov"] = (OBS_STATE_DIM, OBS_STATE_AUX_DIM)
382+
342383
if self.k_endog == 1:
343384
coord_map["sigma_state"] = None
344385
coord_map["sigma_obs"] = None
345386
coord_map["initial_level"] = None
346387
coord_map["initial_trend"] = None
388+
else:
389+
coord_map["alpha"] = (OBS_STATE_DIM,)
390+
coord_map["beta"] = (OBS_STATE_DIM,)
391+
coord_map["gamma"] = (OBS_STATE_DIM,)
392+
coord_map["phi"] = (OBS_STATE_DIM,)
393+
coord_map["initial_seasonal"] = (OBS_STATE_DIM, ETS_SEASONAL_DIM)
394+
coord_map["seasonal_param"] = (OBS_STATE_DIM, ETS_SEASONAL_DIM)
395+
347396
if not self.measurement_error:
348397
del coord_map["sigma_obs"]
349398
if not self.seasonal:
@@ -360,6 +409,8 @@ def coords(self) -> dict[str, Sequence]:
360409
return coords
361410

362411
def make_symbolic_graph(self) -> None:
412+
k_states_each = self.k_states // self.k_endog
413+
363414
P0 = self.make_and_register_variable(
364415
"P0", shape=(self.k_states, self.k_states), dtype=floatX
365416
)
@@ -368,21 +419,37 @@ def make_symbolic_graph(self) -> None:
368419
initial_level = self.make_and_register_variable(
369420
"initial_level", shape=(self.k_endog,) if self.k_endog > 1 else (), dtype=floatX
370421
)
371-
self.ssm["initial_state", 1] = initial_level
422+
423+
initial_states = [pt.zeros(k_states_each) for _ in range(self.k_endog)]
424+
if self.k_endog == 1:
425+
initial_states = [pt.set_subtensor(initial_states[0][1], initial_level)]
426+
else:
427+
initial_states = [
428+
pt.set_subtensor(initial_state[1], initial_level[i])
429+
for i, initial_state in enumerate(initial_states)
430+
]
372431

373432
# The shape of R can be pre-allocated, then filled with the required parameters
374-
R = pt.zeros((self.k_states, self.k_posdef))
433+
R = pt.zeros((self.k_states // self.k_endog, 1))
434+
435+
alpha = self.make_and_register_variable(
436+
"alpha", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
437+
)
375438

376-
alpha = self.make_and_register_variable("alpha", shape=(), dtype=floatX)
377-
R = pt.set_subtensor(R[1, 0], alpha) # and l_t = ... + alpha * e_t
439+
if self.k_endog == 1:
440+
# The R[0, 0] entry needs to be adjusted for a shift in the time indices. Consider the (A, N, N) model:
441+
# y_t = l_{t-1} + e_t
442+
# l_t = l_{t-1} + alpha * e_t
443+
R_list = [pt.set_subtensor(R[1, 0], alpha)] # and l_t = ... + alpha * e_t
378444

379-
# The R[0, 0] entry needs to be adjusted for a shift in the time indices. Consider the (A, N, N) model:
380-
# y_t = l_{t-1} + e_t
381-
# l_t = l_{t-1} + alpha * e_t
382-
# We want the first equation to be in terms of time t on the RHS, because our observation equation is always
383-
# y_t = Z @ x_t. Re-arranging equation 2, we get l_{t-1} = l_t - alpha * e_t --> y_t = l_t + e_t - alpha * e_t
384-
# --> y_t = l_t + (1 - alpha) * e_t
385-
R = pt.set_subtensor(R[0, :], (1 - alpha))
445+
# We want the first equation to be in terms of time t on the RHS, because our observation equation is always
446+
# y_t = Z @ x_t. Re-arranging equation 2, we get l_{t-1} = l_t - alpha * e_t --> y_t = l_t + e_t - alpha * e_t
447+
# --> y_t = l_t + (1 - alpha) * e_t
448+
R_list = [pt.set_subtensor(R[0, :], (1 - alpha)) for R in R_list]
449+
else:
450+
# If there are multiple endog, clone the basic R matrix and modify the appropriate entries
451+
R_list = [pt.set_subtensor(R[1, 0], alpha[i]) for i in range(self.k_endog)]
452+
R_list = [pt.set_subtensor(R[0, :], (1 - alpha[i])) for i, R in enumerate(R_list)]
386453

387454
# Shock and level component always exists, the base case is e_t = e_t and l_t = l_{t-1}
388455
T_base = pt.as_tensor_variable(np.array([[0.0, 0.0], [0.0, 1.0]]))
@@ -391,77 +458,134 @@ def make_symbolic_graph(self) -> None:
391458
initial_trend = self.make_and_register_variable(
392459
"initial_trend", shape=(self.k_endog,) if self.k_endog > 1 else (), dtype=floatX
393460
)
394-
self.ssm["initial_state", 2] = initial_trend
395461

396-
beta = self.make_and_register_variable("beta", shape=(), dtype=floatX)
462+
if self.k_endog == 1:
463+
initial_states = [pt.set_subtensor(initial_states[0][2], initial_trend)]
464+
else:
465+
initial_states = [
466+
pt.set_subtensor(initial_state[2], initial_trend[i])
467+
for i, initial_state in enumerate(initial_states)
468+
]
469+
beta = self.make_and_register_variable(
470+
"beta", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
471+
)
397472
if self.use_transformed_parameterization:
398-
R = pt.set_subtensor(R[2, 0], beta)
473+
param = beta
474+
else:
475+
param = alpha * beta
476+
if self.k_endog == 1:
477+
R_list = [pt.set_subtensor(R[2, 0], param) for R in R_list]
399478
else:
400-
R = pt.set_subtensor(R[2, 0], alpha * beta)
479+
R_list = [pt.set_subtensor(R[2, 0], param[i]) for i, R in enumerate(R_list)]
401480

402481
# If a trend is requested, we have the following transition equations (omitting the shocks):
403482
# l_t = l_{t-1} + b_{t-1}
404483
# b_t = b_{t-1}
405484
T_base = pt.as_tensor_variable(([0.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]))
406485

407486
if self.damped_trend:
408-
phi = self.make_and_register_variable("phi", shape=(), dtype=floatX)
487+
phi = self.make_and_register_variable(
488+
"phi", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
489+
)
409490
# We are always in the case where we have a trend, so we can add the dampening parameter to T_base defined
410491
# in that branch. Transition equations become:
411492
# l_t = l_{t-1} + phi * b_{t-1}
412493
# b_t = phi * b_{t-1}
413-
T_base = pt.set_subtensor(T_base[1:, 2], phi)
494+
if self.k_endog > 1:
495+
T_base = [pt.set_subtensor(T_base[1:, 2], phi[i]) for i in range(self.k_endog)]
496+
else:
497+
T_base = pt.set_subtensor(T_base[1:, 2], phi)
414498

415-
T_components = [T_base]
499+
T_components = (
500+
[T_base for _ in range(self.k_endog)] if not isinstance(T_base, list) else T_base
501+
)
416502

417503
if self.seasonal:
418504
initial_seasonal = self.make_and_register_variable(
419-
"initial_seasonal", shape=(self.seasonal_periods,), dtype=floatX
505+
"initial_seasonal",
506+
shape=(self.seasonal_periods,)
507+
if self.k_endog == 1
508+
else (self.k_endog, self.seasonal_periods),
509+
dtype=floatX,
420510
)
421-
422-
self.ssm["initial_state", 2 + int(self.trend) :] = initial_seasonal
423-
424-
gamma = self.make_and_register_variable("gamma", shape=(), dtype=floatX)
425-
426-
if self.use_transformed_parameterization:
427-
param = gamma
511+
if self.k_endog == 1:
512+
initial_states = [
513+
pt.set_subtensor(initial_states[0][2 + int(self.trend) :], initial_seasonal)
514+
]
428515
else:
429-
param = (1 - alpha) * gamma
516+
initial_states = [
517+
pt.set_subtensor(initial_state[2 + int(self.trend) :], initial_seasonal[i])
518+
for i, initial_state in enumerate(initial_states)
519+
]
430520

431-
R = pt.set_subtensor(R[2 + int(self.trend), 0], param)
521+
gamma = self.make_and_register_variable(
522+
"gamma", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
523+
)
432524

525+
param = gamma if self.use_transformed_parameterization else (1 - alpha) * gamma
433526
# Additional adjustment to the R[0, 0] position is required. Start from:
434527
# y_t = l_{t-1} + s_{t-m} + e_t
435528
# l_t = l_{t-1} + alpha * e_t
436529
# s_t = s_{t-m} + gamma * e_t
437530
# Solve for l_{t-1} and s_{t-m} in terms of l_t and s_t, then substitute into the observation equation:
438531
# y_t = l_t + s_t - alpha * e_t - gamma * e_t + e_t --> y_t = l_t + s_t + (1 - alpha - gamma) * e_t
439-
R = pt.set_subtensor(R[0, 0], R[0, 0] - param)
532+
533+
if self.k_endog == 1:
534+
R_list = [pt.set_subtensor(R[2 + int(self.trend), 0], param) for R in R_list]
535+
R_list = [pt.set_subtensor(R[0, 0], R[0, 0] - param) for R in R_list]
536+
537+
else:
538+
R_list = [
539+
pt.set_subtensor(R[2 + int(self.trend), 0], param[i])
540+
for i, R in enumerate(R_list)
541+
]
542+
R_list = [
543+
pt.set_subtensor(R[0, 0], R[0, 0] - param[i]) for i, R in enumerate(R_list)
544+
]
440545

441546
# The seasonal component is always going to look like a TimeFrequency structural component, see that
442547
# docstring for more details
443-
T_seasonal = pt.eye(self.seasonal_periods, k=-1)
444-
T_seasonal = pt.set_subtensor(T_seasonal[0, -1], 1.0)
445-
T_components += [T_seasonal]
548+
T_seasonals = [pt.eye(self.seasonal_periods, k=-1) for _ in range(self.k_endog)]
549+
T_seasonals = [pt.set_subtensor(T_seasonal[0, -1], 1.0) for T_seasonal in T_seasonals]
550+
551+
# Organize the components so it goes T1, T_seasonal_1, T2, T_seasonal_2, etc.
552+
T_components = [
553+
matrix[i] for i in range(self.k_endog) for matrix in [T_components, T_seasonals]
554+
]
446555

447-
self.ssm["selection"] = R
556+
x0 = pt.concatenate(initial_states, axis=0)
557+
R = pt.linalg.block_diag(*R_list)
558+
559+
self.ssm["initial_state"] = x0
560+
self.ssm["selection"] = pt.specify_shape(R, shape=(self.k_states, self.k_posdef))
448561

449562
T = pt.linalg.block_diag(*T_components)
450563
self.ssm["transition"] = pt.specify_shape(T, (self.k_states, self.k_states))
451564

452-
Z = np.zeros((self.k_endog, self.k_states))
453-
Z[0, 0] = 1.0 # innovation
454-
Z[0, 1] = 1.0 # level
455-
if self.seasonal:
456-
Z[0, 2 + int(self.trend)] = 1.0
565+
Zs = [np.zeros((self.k_endog, self.k_states // self.k_endog)) for _ in range(self.k_endog)]
566+
for i, Z in enumerate(Zs):
567+
Z[i, 0] = 1.0 # innovation
568+
Z[i, 1] = 1.0 # level
569+
if self.seasonal:
570+
Z[i, 2 + int(self.trend)] = 1.0
571+
572+
Z = pt.concatenate(Zs, axis=1)
573+
457574
self.ssm["design"] = Z
458575

459576
# Set up the state covariance matrix
460-
state_cov_idx = ("state_cov", *np.diag_indices(self.k_posdef))
461-
state_cov = self.make_and_register_variable(
462-
"sigma_state", shape=() if self.k_posdef == 1 else (self.k_posdef,), dtype=floatX
463-
)
464-
self.ssm[state_cov_idx] = state_cov**2
577+
if self.dense_innovation_covariance:
578+
state_cov = self.make_and_register_variable(
579+
"state_cov", shape=(self.k_posdef, self.k_posdef), dtype=floatX
580+
)
581+
self.ssm["state_cov"] = state_cov
582+
583+
else:
584+
state_cov_idx = ("state_cov", *np.diag_indices(self.k_posdef))
585+
state_cov = self.make_and_register_variable(
586+
"sigma_state", shape=() if self.k_posdef == 1 else (self.k_posdef,), dtype=floatX
587+
)
588+
self.ssm[state_cov_idx] = state_cov**2
465589

466590
if self.measurement_error:
467591
obs_cov_idx = ("obs_cov", *np.diag_indices(self.k_endog))

0 commit comments

Comments
 (0)