@@ -81,15 +81,18 @@ def __init__(
81
81
self ,
82
82
ssm : PytensorRepresentation ,
83
83
state_names ,
84
+ data_names ,
84
85
shock_names ,
85
86
param_names ,
86
87
exog_names ,
87
88
param_dims ,
88
89
coords ,
89
90
param_info ,
91
+ data_info ,
90
92
component_info ,
91
93
measurement_error ,
92
94
name_to_variable ,
95
+ name_to_data ,
93
96
name = None ,
94
97
verbose = True ,
95
98
filter_type : str = "standard" ,
@@ -104,6 +107,7 @@ def __init__(
104
107
param_names , param_dims , param_info , k_states
105
108
)
106
109
self ._state_names = state_names
110
+ self ._data_names = data_names
107
111
self ._shock_names = shock_names
108
112
self ._param_names = param_names
109
113
self ._param_dims = param_dims
@@ -113,6 +117,7 @@ def __init__(
113
117
114
118
self ._coords = coords
115
119
self ._param_info = param_info
120
+ self ._data_info = data_info
116
121
self .measurement_error = measurement_error
117
122
118
123
super ().__init__ (
@@ -126,7 +131,10 @@ def __init__(
126
131
127
132
self .ssm = ssm
128
133
self ._component_info = component_info
134
+
129
135
self ._name_to_variable = name_to_variable
136
+ self ._name_to_data = name_to_data
137
+
130
138
self ._exog_names = exog_names
131
139
self ._needs_exog_data = len (exog_names ) > 0
132
140
@@ -149,6 +157,10 @@ def _add_inital_state_cov_to_properties(param_names, param_dims, param_info, k_s
149
157
def param_names (self ):
150
158
return self ._param_names
151
159
160
+ @property
161
+ def data_names (self ) -> List [str ]:
162
+ return self ._data_names
163
+
152
164
@property
153
165
def state_names (self ):
154
166
return self ._state_names
@@ -173,6 +185,10 @@ def coords(self) -> Dict[str, Sequence]:
173
185
def param_info (self ) -> Dict [str , Dict [str , Any ]]:
174
186
return self ._param_info
175
187
188
+ @property
189
+ def data_info (self ) -> dict [str , dict [str , Any ]]:
190
+ return self ._data_info
191
+
176
192
def make_symbolic_graph (self ) -> None :
177
193
"""
178
194
Assign placeholder pytensor variables among statespace matrices in positions where PyMC variables will go.
@@ -338,6 +354,7 @@ def __init__(
338
354
k_states ,
339
355
k_posdef ,
340
356
state_names = None ,
357
+ data_names = None ,
341
358
shock_names = None ,
342
359
param_names = None ,
343
360
exog_names = None ,
@@ -354,14 +371,18 @@ def __init__(
354
371
self .measurement_error = measurement_error
355
372
356
373
self .state_names = state_names if state_names is not None else []
374
+ self .data_names = data_names if data_names is not None else []
357
375
self .shock_names = shock_names if shock_names is not None else []
358
376
self .param_names = param_names if param_names is not None else []
359
377
self .exog_names = exog_names if exog_names is not None else []
360
378
361
379
self .needs_exog_data = len (self .exog_names ) > 0
362
380
self .coords = {}
363
381
self .param_dims = {}
382
+
364
383
self .param_info = {}
384
+ self .data_info = {}
385
+
365
386
self .param_counts = {}
366
387
367
388
if representation is None :
@@ -370,6 +391,7 @@ def __init__(
370
391
self .ssm = representation
371
392
372
393
self ._name_to_variable = {}
394
+ self ._name_to_data = {}
373
395
374
396
if not component_from_sum :
375
397
self .populate_component_properties ()
@@ -429,6 +451,43 @@ def make_and_register_variable(self, name, shape, dtype=floatX) -> Variable:
429
451
self ._name_to_variable [name ] = placeholder
430
452
return placeholder
431
453
454
+ def make_and_register_data (self , name , shape , dtype = floatX ) -> Variable :
455
+ r"""
456
+ Helper function to create a pytensor symbolic variable and register it in the _name_to_data dictionary
457
+
458
+ Parameters
459
+ ----------
460
+ name : str
461
+ The name of the placeholder data. Must be the name of an expected data variable.
462
+ shape : int or tuple of int
463
+ Shape of the parameter
464
+ dtype : str, default pytensor.config.floatX
465
+ dtype of the parameter
466
+
467
+ Notes
468
+ -----
469
+ See docstring for make_and_register_variable for more details. This function is similar, but handles data
470
+ inputs instead of model parameters.
471
+
472
+ An error is raised if the provided name has already been registered, or if the name is not present in the
473
+ ``data_names`` property.
474
+ """
475
+ if name not in self .data_names :
476
+ raise ValueError (
477
+ f"{ name } is not a model parameter. All placeholder variables should correspond to model "
478
+ f"parameters."
479
+ )
480
+
481
+ if name in self ._name_to_data .keys ():
482
+ raise ValueError (
483
+ f"{ name } is already a registered placeholder variable with shape "
484
+ f"{ self ._name_to_data [name ].type .shape } "
485
+ )
486
+
487
+ placeholder = pt .tensor (name , shape = shape , dtype = dtype )
488
+ self ._name_to_data [name ] = placeholder
489
+ return placeholder
490
+
432
491
def make_symbolic_graph (self ) -> None :
433
492
raise NotImplementedError
434
493
@@ -481,7 +540,6 @@ def make_slice(name, x, o_x):
481
540
transition .name = T .name
482
541
483
542
design = pt .concatenate (conform_time_varying_and_time_invariant_matrices (Z , o_Z ), axis = - 1 )
484
-
485
543
design .name = Z .name
486
544
487
545
selection = block_diagonal ([R , o_R ])
@@ -542,14 +600,18 @@ def _make_combined_name(self):
542
600
543
601
def __add__ (self , other ):
544
602
state_names = self ._combine_property (other , "state_names" )
603
+ data_names = self ._combine_property (other , "data_names" )
545
604
param_names = self ._combine_property (other , "param_names" )
546
605
shock_names = self ._combine_property (other , "shock_names" )
547
606
param_info = self ._combine_property (other , "param_info" )
607
+ data_info = self ._combine_property (other , "data_info" )
548
608
param_dims = self ._combine_property (other , "param_dims" )
549
609
coords = self ._combine_property (other , "coords" )
550
610
exog_names = self ._combine_property (other , "exog_names" )
551
611
552
612
_name_to_variable = self ._combine_property (other , "_name_to_variable" )
613
+ _name_to_data = self ._combine_property (other , "_name_to_data" )
614
+
553
615
measurement_error = any ([self .measurement_error , other .measurement_error ])
554
616
555
617
k_states , k_posdef , k_endog = self ._get_combined_shapes (other )
@@ -567,30 +629,22 @@ def __add__(self, other):
567
629
new_comp ._component_info = self ._combine_component_info (other )
568
630
new_comp .name = new_comp ._make_combined_name ()
569
631
570
- property_names = [
571
- "state_names" ,
572
- "param_names" ,
573
- "shock_names" ,
574
- "state_dims" ,
575
- "coords" ,
576
- "param_dims" ,
577
- "param_info" ,
578
- "exog_names" ,
579
- "_name_to_variable" ,
580
- ]
581
- property_values = [
582
- state_names ,
583
- param_names ,
584
- shock_names ,
585
- param_dims ,
586
- coords ,
587
- param_dims ,
588
- param_info ,
589
- exog_names ,
590
- _name_to_variable ,
632
+ names_and_props = [
633
+ ("state_names" , state_names ),
634
+ ("data_names" , data_names ),
635
+ ("param_names" , param_names ),
636
+ ("shock_names" , shock_names ),
637
+ ("param_dims" , param_dims ),
638
+ ("coords" , coords ),
639
+ ("param_dims" , param_dims ),
640
+ ("param_info" , param_info ),
641
+ ("data_info" , data_info ),
642
+ ("exog_names" , exog_names ),
643
+ ("_name_to_variable" , _name_to_variable ),
644
+ ("_name_to_data" , _name_to_data ),
591
645
]
592
646
593
- for prop , value in zip ( property_names , property_values ) :
647
+ for prop , value in names_and_props :
594
648
setattr (new_comp , prop , value )
595
649
596
650
return new_comp
@@ -622,15 +676,18 @@ def build(self, name=None, filter_type="standard", verbose=True):
622
676
self .ssm ,
623
677
name = name ,
624
678
state_names = self .state_names ,
679
+ data_names = self .data_names ,
625
680
shock_names = self .shock_names ,
626
681
param_names = self .param_names ,
627
682
param_dims = self .param_dims ,
628
683
coords = self .coords ,
629
684
param_info = self .param_info ,
685
+ data_info = self .data_info ,
630
686
component_info = self ._component_info ,
631
687
measurement_error = self .measurement_error ,
632
688
exog_names = self .exog_names ,
633
689
name_to_variable = self ._name_to_variable ,
690
+ name_to_data = self ._name_to_data ,
634
691
filter_type = filter_type ,
635
692
verbose = verbose ,
636
693
)
@@ -881,7 +938,8 @@ def populate_component_properties(self):
881
938
}
882
939
883
940
def make_symbolic_graph (self ) -> None :
884
- error_sigma = self .make_and_register_variable (f"sigma_{ self .name } " , shape = (self .k_endog ,))
941
+ sigma_shape = () if self .k_endog == 1 else (self .k_endog ,)
942
+ error_sigma = self .make_and_register_variable (f"sigma_{ self .name } " , shape = sigma_shape )
885
943
diag_idx = np .diag_indices (self .k_endog )
886
944
idx = np .s_ ["obs_cov" , diag_idx [0 ], diag_idx [1 ]]
887
945
self .ssm [idx ] = error_sigma ** 2
@@ -1541,7 +1599,7 @@ def _handle_input_data(self, k_exog: int, state_names: Optional[List[str]], name
1541
1599
1542
1600
def make_symbolic_graph (self ) -> None :
1543
1601
betas = self .make_and_register_variable (f"beta_{ self .name } " , shape = (self .k_states ,))
1544
- regression_data = self .make_and_register_variable (
1602
+ regression_data = self .make_and_register_data (
1545
1603
f"data_{ self .name } " , shape = (None , self .k_states )
1546
1604
)
1547
1605
@@ -1560,17 +1618,19 @@ def make_symbolic_graph(self) -> None:
1560
1618
def populate_component_properties (self ) -> None :
1561
1619
self .shock_names = self .state_names
1562
1620
1563
- self .param_names = [f"beta_{ self .name } " , f"data_{ self .name } " ]
1621
+ self .param_names = [f"beta_{ self .name } " ]
1622
+ self .data_names = [f"data_{ self .name } " ]
1564
1623
self .param_dims = {
1565
1624
f"beta_{ self .name } " : ("exog_state" ,),
1566
- f"data_{ self .name } " : (TIME_DIM , "exog_state" ),
1567
1625
}
1568
1626
1569
1627
self .param_info = {
1570
1628
f"beta_{ self .name } " : {"shape" : (1 ,), "constraints" : None , "dims" : ("exog_state" ,)},
1629
+ }
1630
+
1631
+ self .data_info = {
1571
1632
f"data_{ self .name } " : {
1572
1633
"shape" : (None , self .k_states ),
1573
- "constraints" : None ,
1574
1634
"dims" : (TIME_DIM , "exog_state" ),
1575
1635
},
1576
1636
}
0 commit comments