Skip to content

Commit 381745e

Browse files
jessegrabowskiricardoV94
authored andcommitted
Scalar parameters in StructuralComponents are always expected to be scalar values, rather than 1d arrays
1 parent 1f70c7a commit 381745e

File tree

2 files changed

+40
-40
lines changed

2 files changed

+40
-40
lines changed

pymc_experimental/statespace/models/structural.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ def order_to_mask(order):
4040
def _frequency_transition_block(s, j):
4141
lam = 2 * np.pi * j / s
4242

43-
# Squeeze because otherwise if lamb has shape (1,), T will have shape (2, 2, 1)
44-
return pt.stack([[pt.cos(lam), pt.sin(lam)], [-pt.sin(lam), pt.cos(lam)]]).squeeze()
43+
return pt.stack([[pt.cos(lam), pt.sin(lam)], [-pt.sin(lam), pt.cos(lam)]])
4544

4645

4746
class StructuralTimeSeries(PyMCStateSpace):
@@ -914,7 +913,7 @@ def populate_component_properties(self):
914913
self.param_dims = {f"sigma_{self.name}": (OBS_STATE_DIM,)}
915914
self.param_info = {
916915
f"sigma_{self.name}": {
917-
"shape": (1,),
916+
"shape": (self.k_endog,),
918917
"constraints": "Positive",
919918
"dims": (OBS_STATE_DIM,),
920919
}
@@ -1015,13 +1014,13 @@ def populate_component_properties(self):
10151014
"constraints": None,
10161015
"dims": (AR_PARAM_DIM,),
10171016
},
1018-
"sigma_ar": {"shape": (1,), "constraints": "Positive", "dims": None},
1017+
"sigma_ar": {"shape": (), "constraints": "Positive", "dims": None},
10191018
}
10201019

10211020
def make_symbolic_graph(self) -> None:
10221021
k_nonzero = int(sum(self.order))
10231022
ar_params = self.make_and_register_variable("ar_params", shape=(k_nonzero,))
1024-
sigma_ar = self.make_and_register_variable("sigma_ar", shape=(1,))
1023+
sigma_ar = self.make_and_register_variable("sigma_ar", shape=())
10251024

10261025
T = np.eye(self.k_states, k=-1)
10271026
self.ssm["transition", :, :] = T
@@ -1194,7 +1193,7 @@ def populate_component_properties(self):
11941193
if self.innovations:
11951194
self.param_names += [f"sigma_{self.name}"]
11961195
self.param_info[f"sigma_{self.name}"] = {
1197-
"shape": (1,),
1196+
"shape": (),
11981197
"constraints": "Positive",
11991198
"dims": None,
12001199
}
@@ -1214,7 +1213,7 @@ def make_symbolic_graph(self) -> None:
12141213

12151214
if self.innovations:
12161215
self.ssm["selection", 0, 0] = 1
1217-
season_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=(1,))
1216+
season_sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=())
12181217
cov_idx = ("state_cov", *np.diag_indices(1))
12191218
self.ssm[cov_idx] = season_sigma**2
12201219

@@ -1313,7 +1312,7 @@ def make_symbolic_graph(self) -> None:
13131312
self.ssm["transition", :, :] = T
13141313

13151314
if self.innovations:
1316-
sigma_season = self.make_and_register_variable(f"sigma_{self.name}", shape=(1,))
1315+
sigma_season = self.make_and_register_variable(f"sigma_{self.name}", shape=())
13171316
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_season**2
13181317
self.ssm["selection", :, :] = np.eye(self.k_states)
13191318

@@ -1339,7 +1338,7 @@ def populate_component_properties(self):
13391338
self.shock_names = self.state_names.copy()
13401339
self.param_names += [f"sigma_{self.name}"]
13411340
self.param_info[f"sigma_{self.name}"] = {
1342-
"shape": (1,),
1341+
"shape": (),
13431342
"constraints": "Positive",
13441343
"dims": None,
13451344
}
@@ -1480,20 +1479,20 @@ def make_symbolic_graph(self) -> None:
14801479
self.ssm["initial_state", :] = init_state
14811480

14821481
if self.estimate_cycle_length:
1483-
lamb = self.make_and_register_variable(f"{self.name}_length", shape=(1,))
1482+
lamb = self.make_and_register_variable(f"{self.name}_length", shape=())
14841483
else:
14851484
lamb = self.cycle_length
14861485

14871486
if self.dampen:
1488-
rho = self.make_and_register_variable(f"{self.name}_dampening_factor", shape=(1,))
1487+
rho = self.make_and_register_variable(f"{self.name}_dampening_factor", shape=())
14891488
else:
14901489
rho = 1
14911490

14921491
T = rho * _frequency_transition_block(lamb, j=1)
14931492
self.ssm["transition", :, :] = T
14941493

14951494
if self.innovations:
1496-
sigma_cycle = self.make_and_register_variable(f"sigma_{self.name}", shape=(1,))
1495+
sigma_cycle = self.make_and_register_variable(f"sigma_{self.name}", shape=())
14971496
self.ssm["state_cov", :, :] = pt.eye(self.k_posdef) * sigma_cycle**2
14981497

14991498
def populate_component_properties(self):
@@ -1511,23 +1510,23 @@ def populate_component_properties(self):
15111510
if self.estimate_cycle_length:
15121511
self.param_names += [f"{self.name}_length"]
15131512
self.param_info[f"{self.name}_length"] = {
1514-
"shape": (1,),
1513+
"shape": (),
15151514
"constraints": "Positive, non-zero",
15161515
"dims": None,
15171516
}
15181517

15191518
if self.dampen:
15201519
self.param_names += [f"{self.name}_dampening_factor"]
15211520
self.param_info[f"{self.name}_dampening_factor"] = {
1522-
"shape": (1,),
1521+
"shape": (),
15231522
"constraints": "0 < x ≤ 1",
15241523
"dims": None,
15251524
}
15261525

15271526
if self.innovations:
15281527
self.param_names += [f"sigma_{self.name}"]
15291528
self.param_info[f"sigma_{self.name}"] = {
1530-
"shape": (1,),
1529+
"shape": (),
15311530
"constraints": "Positive",
15321531
"dims": None,
15331532
}
@@ -1609,7 +1608,11 @@ def populate_component_properties(self) -> None:
16091608
}
16101609

16111610
self.param_info = {
1612-
f"beta_{self.name}": {"shape": (1,), "constraints": None, "dims": ("exog_state",)},
1611+
f"beta_{self.name}": {
1612+
"shape": (self.k_states,),
1613+
"constraints": None,
1614+
"dims": ("exog_state",),
1615+
},
16131616
}
16141617

16151618
self.data_info = {
@@ -1624,7 +1627,7 @@ def populate_component_properties(self) -> None:
16241627
self.param_names += [f"sigma_beta_{self.name}"]
16251628
self.param_dims[f"sigma_beta_{self.name}"] = "exog_state"
16261629
self.param_info[f"sigma_beta_{self.name}"] = {
1627-
"shape": (1,),
1630+
"shape": (),
16281631
"constraints": "Positive",
16291632
"dims": ("exog_state",),
16301633
}

pymc_experimental/tests/statespace/test_structural.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _assert_params_info_correct(param_info, coords, param_dims):
130130
else:
131131
inferred_dims = None
132132

133-
shape = tuple(len(label) for label in labels) if labels is not None else (1,)
133+
shape = tuple(len(label) for label in labels) if labels is not None else ()
134134

135135
assert info["shape"] == shape
136136
assert dims == inferred_dims
@@ -196,9 +196,9 @@ def create_structural_model_and_equivalent_statsmodel(
196196
components = []
197197

198198
if irregular:
199-
sigma2 = np.abs(rng.normal()).astype(floatX)
199+
sigma2 = np.abs(rng.normal()).astype(floatX).item()
200200
params["sigma_irregular"] = np.sqrt(sigma2)
201-
sm_params["sigma2.irregular"] = sigma2.item()
201+
sm_params["sigma2.irregular"] = sigma2
202202
expected_param_dims["sigma_irregular"] += ("observed_state",)
203203

204204
comp = st.MeasurementError("irregular")
@@ -298,7 +298,7 @@ def create_structural_model_and_equivalent_statsmodel(
298298
sm_init.update(seasonal_dict)
299299

300300
if stochastic_seasonal:
301-
sigma2 = np.abs(rng.normal(size=(1,))).astype(floatX)
301+
sigma2 = np.abs(rng.normal()).astype(floatX)
302302
params["sigma_seasonal"] = np.sqrt(sigma2)
303303
sm_params["sigma2.seasonal"] = sigma2
304304
expected_coords[SHOCK_DIM] += [
@@ -322,9 +322,6 @@ def create_structural_model_and_equivalent_statsmodel(
322322
n_states = 2 * n - int(last_state_not_identified)
323323
state_names = [f"seasonal_{s}_{f}_{i}" for i in range(n) for f in ["Cos", "Sin"]]
324324

325-
# if last_state_not_identified:
326-
# state_names.pop(-1)
327-
328325
seasonal_params = rng.normal(size=n_states).astype(floatX)
329326

330327
params[f"seasonal_{s}"] = seasonal_params
@@ -343,7 +340,7 @@ def create_structural_model_and_equivalent_statsmodel(
343340
state_count += 1
344341

345342
if has_innov:
346-
sigma2 = np.abs(rng.normal(size=(1,))).astype(floatX)
343+
sigma2 = np.abs(rng.normal()).astype(floatX)
347344
params[f"sigma_seasonal_{s}"] = np.sqrt(sigma2)
348345
sm_params[f"sigma2.freq_seasonal_{s}({n})"] = sigma2
349346
expected_coords[SHOCK_DIM] += state_names
@@ -359,7 +356,7 @@ def create_structural_model_and_equivalent_statsmodel(
359356

360357
# Statsmodels takes the frequency not the cycle length, so convert it.
361358
sm_params["frequency.cycle"] = 2.0 * np.pi / cycle_length
362-
params["cycle_length"] = np.atleast_1d(cycle_length)
359+
params["cycle_length"] = cycle_length
363360

364361
init_cycle = rng.normal(size=(2,)).astype(floatX)
365362
params["cycle"] = init_cycle
@@ -374,15 +371,15 @@ def create_structural_model_and_equivalent_statsmodel(
374371
sm_init["cycle.auxilliary"] = init_cycle[1]
375372

376373
if stochastic_cycle:
377-
sigma2 = np.abs(rng.normal(size=(1,))).astype(floatX)
374+
sigma2 = np.abs(rng.normal()).astype(floatX)
378375
params["sigma_cycle"] = np.sqrt(sigma2)
379376
expected_coords[SHOCK_DIM] += state_names
380377
expected_coords[SHOCK_AUX_DIM] += state_names
381378

382379
sm_params["sigma2.cycle"] = sigma2
383380

384381
if damped_cycle:
385-
rho = rng.beta(1, 1, size=(1,)).astype(floatX)
382+
rho = rng.beta(1, 1)
386383
params["cycle_dampening_factor"] = rho
387384
sm_params["damping.cycle"] = rho
388385

@@ -398,7 +395,9 @@ def create_structural_model_and_equivalent_statsmodel(
398395
if autoregressive is not None:
399396
ar_names = [f"L{i+1}.data" for i in range(autoregressive)]
400397
ar_params = rng.normal(size=(autoregressive,)).astype(floatX)
401-
sigma2 = np.abs(rng.normal(size=(1,))).astype(floatX)
398+
if autoregressive == 1:
399+
ar_params = ar_params.item()
400+
sigma2 = np.abs(rng.normal()).astype(floatX)
402401

403402
params["ar_params"] = ar_params
404403
params["sigma_ar"] = np.sqrt(sigma2)
@@ -550,8 +549,9 @@ def test_autoregressive_model(order, rng):
550549
ar = st.AutoregressiveComponent(order=order)
551550
params = {
552551
"ar_params": np.full((sum(ar.order),), 0.5, dtype=floatX),
553-
"sigma_ar": np.array([0.0], dtype=floatX),
552+
"sigma_ar": 0.0,
554553
}
554+
555555
x, y = simulate_from_numpy_model(ar, rng, params, steps=100)
556556

557557
# Check coords
@@ -578,7 +578,7 @@ def random_word(rng):
578578

579579
params = {"season_coefs": x0}
580580
if mod.innovations:
581-
params["sigma_season"] = np.array([0.0], dtype=floatX)
581+
params["sigma_season"] = 0.0
582582

583583
x, y = simulate_from_numpy_model(mod, rng, params)
584584
y = y.ravel()
@@ -604,7 +604,7 @@ def get_shift_factor(s):
604604
def test_frequency_seasonality(n, s, rng):
605605
mod = st.FrequencySeasonality(season_length=s, n=n, name="season")
606606
x0 = rng.normal(size=mod.n_coefs).astype(floatX)
607-
params = {"season": x0, "sigma_season": np.array([0.0], dtype=floatX)}
607+
params = {"season": x0, "sigma_season": 0.0}
608608
k = get_shift_factor(s)
609609
T = int(s * k)
610610

@@ -641,10 +641,7 @@ def test_cycle_component_with_dampening(rng):
641641
cycle = st.CycleComponent(
642642
name="cycle", cycle_length=12, estimate_cycle_length=False, innovations=False, dampen=True
643643
)
644-
params = {
645-
"cycle": np.array([10.0, 10.0], dtype=floatX),
646-
"cycle_dampening_factor": np.array([0.75], dtype=floatX),
647-
}
644+
params = {"cycle": np.array([10.0, 10.0], dtype=floatX), "cycle_dampening_factor": 0.75}
648645
x, y = simulate_from_numpy_model(cycle, rng, params, steps=100)
649646

650647
# Check that the cycle dampens to zero over time
@@ -657,9 +654,9 @@ def test_cycle_component_with_innovations_and_cycle_length(rng):
657654
)
658655
params = {
659656
"cycle": np.array([1.0, 1.0], dtype=floatX),
660-
"cycle_length": np.array([12], dtype=floatX),
661-
"cycle_dampening_factor": np.array([0.95], dtype=floatX),
662-
"sigma_cycle": np.array([1.0], dtype=floatX),
657+
"cycle_length": 12.0,
658+
"cycle_dampening_factor": 0.95,
659+
"sigma_cycle": 1.0,
663660
}
664661

665662
x, y = simulate_from_numpy_model(cycle, rng, params)
@@ -707,7 +704,7 @@ def test_add_components():
707704
}
708705
se_params = {
709706
"seasonal_coefs": np.ones(11, dtype=floatX),
710-
"sigma_seasonal": np.ones(1, dtype=floatX),
707+
"sigma_seasonal": 1.0,
711708
}
712709
all_params = ll_params.copy()
713710
all_params.update(se_params)

0 commit comments

Comments
 (0)