Skip to content

Commit 93265aa

Browse files
Clean up type hints
Use built-in type hints in place of `typing.List` and `typing.Dict`
1 parent ec053e5 commit 93265aa

File tree

9 files changed

+102
-97
lines changed

9 files changed

+102
-97
lines changed

pymc_experimental/statespace/core/representation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional, Tuple, Type, Union
1+
from typing import Optional, Type, Union
22

33
import numpy as np
44
import pytensor
@@ -10,7 +10,7 @@
1010
)
1111

1212
floatX = pytensor.config.floatX
13-
KeyLike = Union[Tuple[Union[str, int]], str]
13+
KeyLike = Union[tuple[Union[str, int]], str]
1414

1515

1616
class PytensorRepresentation:
@@ -228,8 +228,8 @@ def _update_shape(self, key: KeyLike, value: Union[np.ndarray, pt.TensorType]) -
228228
self.shapes[key] = shape
229229

230230
def _add_time_dim_to_slice(
231-
self, name: str, slice_: Union[List[int], Tuple[int]], n_dim: int
232-
) -> Tuple[int]:
231+
self, name: str, slice_: Union[list[int], tuple[int]], n_dim: int
232+
) -> tuple[int]:
233233
# Case 1: There is never a time dim. No changes needed.
234234
if name in NEVER_TIME_VARYING:
235235
return slice_

pymc_experimental/statespace/core/statespace.py

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
2+
from typing import Any, Callable, Optional, Sequence, Union
33

44
import numpy as np
55
import pandas as pd
@@ -219,8 +219,8 @@ def __init__(
219219
measurement_error: bool = False,
220220
):
221221
self._fit_mode: Optional[str] = None
222-
self._fit_coords: Optional[dict[str, Sequence[str, ...]]] = None
223-
self._fit_dims: Optional[dict[str, Sequence[str, ...]]] = None
222+
self._fit_coords: Optional[dict[str, Sequence[str]]] = None
223+
self._fit_dims: Optional[dict[str, Sequence[str]]] = None
224224
self._fit_data: Optional[pt.TensorVariable] = None
225225

226226
self._needs_exog_data = False
@@ -237,7 +237,7 @@ def __init__(
237237
self.ssm = PytensorRepresentation(k_endog, k_states, k_posdef)
238238

239239
# This will be populated with PyMC random matrices after calling _insert_random_variables
240-
self.subbed_ssm: Optional[list[pt.TensorVariable, ...]] = None
240+
self.subbed_ssm: Optional[list[pt.TensorVariable]] = None
241241

242242
if filter_type.lower() not in FILTER_FACTORY.keys():
243243
raise NotImplementedError(
@@ -296,7 +296,7 @@ def _print_data_requirements(self) -> None:
296296
f"{out}"
297297
)
298298

299-
def _unpack_statespace_with_placeholders(self) -> Tuple:
299+
def _unpack_statespace_with_placeholders(self) -> tuple[pt.TensorVariable, ...]:
300300
"""
301301
Helper function to quickly obtain all statespace matrices in the standard order. Matrices returned by this
302302
method will include pytensor placeholders.
@@ -314,7 +314,7 @@ def _unpack_statespace_with_placeholders(self) -> Tuple:
314314

315315
return a0, P0, c, d, T, Z, R, H, Q
316316

317-
def unpack_statespace(self) -> list[pt.TensorVariable, ...]:
317+
def unpack_statespace(self) -> list[pt.TensorVariable]:
318318
"""
319319
Helper function to quickly obtain all statespace matrices in the standard order.
320320
"""
@@ -329,7 +329,7 @@ def unpack_statespace(self) -> list[pt.TensorVariable, ...]:
329329
return self.subbed_ssm
330330

331331
@property
332-
def param_names(self) -> List[str]:
332+
def param_names(self) -> list[str]:
333333
"""
334334
Names of model parameters
335335
@@ -339,7 +339,7 @@ def param_names(self) -> List[str]:
339339
raise NotImplementedError("The param_names property has not been implemented!")
340340

341341
@property
342-
def data_names(self) -> List[str]:
342+
def data_names(self) -> list[str]:
343343
"""
344344
Names of data variables expected by the model.
345345
@@ -349,7 +349,7 @@ def data_names(self) -> List[str]:
349349
raise NotImplementedError("The data_names property has not been implemented!")
350350

351351
@property
352-
def param_info(self) -> Dict[str, Dict[str, Any]]:
352+
def param_info(self) -> dict[str, dict[str, Any]]:
353353
"""
354354
Information about parameters needed to declare priors
355355
@@ -377,7 +377,7 @@ def data_info(self) -> dict[str, dict[str, Any]]:
377377
raise NotImplementedError("The data_info property has not been implemented!")
378378

379379
@property
380-
def state_names(self) -> List[str]:
380+
def state_names(self) -> list[str]:
381381
"""
382382
A k_states length list of strings, associated with the model's hidden states
383383
@@ -386,22 +386,22 @@ def state_names(self) -> List[str]:
386386
raise NotImplementedError("The state_names property has not been implemented!")
387387

388388
@property
389-
def observed_states(self) -> List[str]:
389+
def observed_states(self) -> list[str]:
390390
"""
391391
A k_endog length list of strings, associated with the model's observed states
392392
"""
393393
raise NotImplementedError("The observed_states property has not been implemented!")
394394

395395
@property
396-
def shock_names(self) -> List[str]:
396+
def shock_names(self) -> list[str]:
397397
"""
398398
A k_posdef length list of strings, associated with the model's shock processes
399399
400400
"""
401401
raise NotImplementedError("The shock_names property has not been implemented!")
402402

403403
@property
404-
def default_priors(self) -> Dict[str, Callable]:
404+
def default_priors(self) -> dict[str, Callable]:
405405
"""
406406
Dictionary of parameter names and callable functions to construct default priors for the model
407407
@@ -411,7 +411,7 @@ def default_priors(self) -> Dict[str, Callable]:
411411
raise NotImplementedError("The default_priors property has not been implemented!")
412412

413413
@property
414-
def coords(self) -> Dict[str, Sequence[str]]:
414+
def coords(self) -> dict[str, Sequence[str]]:
415415
"""
416416
PyMC model coordinates
417417
@@ -422,7 +422,7 @@ def coords(self) -> Dict[str, Sequence[str]]:
422422
raise NotImplementedError("The coords property has not been implemented!")
423423

424424
@property
425-
def param_dims(self) -> Dict[str, Sequence[str]]:
425+
def param_dims(self) -> dict[str, Sequence[str]]:
426426
"""
427427
Dictionary of named dimensions for each model parameter
428428
@@ -483,7 +483,9 @@ def make_and_register_variable(self, name, shape, dtype=floatX) -> Variable:
483483
self._name_to_variable[name] = placeholder
484484
return placeholder
485485

486-
def make_and_register_data(self, name, shape, dtype=floatX) -> Variable:
486+
def make_and_register_data(
487+
self, name: str, shape: Union[int, tuple[int]], dtype: str = floatX
488+
) -> Variable:
487489
r"""
488490
Helper function to create a pytensor symbolic variable and register it in the _name_to_data dictionary
489491
@@ -577,7 +579,9 @@ def make_symbolic_graph(self) -> None:
577579
"""
578580
raise NotImplementedError("The make_symbolic_statespace method has not been implemented!")
579581

580-
def _get_matrix_shape_and_dims(self, name: str) -> Tuple[Tuple, Tuple]:
582+
def _get_matrix_shape_and_dims(
583+
self, name: str
584+
) -> tuple[Optional[tuple[int]], Optional[tuple[str]]]:
581585
"""
582586
Get the shape and dimensions of a matrix associated with the specified name.
583587
@@ -614,7 +618,11 @@ def _get_matrix_shape_and_dims(self, name: str) -> Tuple[Tuple, Tuple]:
614618

615619
return shape, dims
616620

617-
def _get_output_shape_and_dims(self, idata: InferenceData, filter_output: str) -> Tuple:
621+
def _get_output_shape_and_dims(
622+
self, idata: InferenceData, filter_output: str
623+
) -> tuple[
624+
Optional[tuple[int]], Optional[tuple[int]], Optional[tuple[str]], Optional[tuple[str]]
625+
]:
618626
"""
619627
Get the shapes and dimensions of the output variables from the provided InferenceData.
620628
@@ -756,15 +764,15 @@ def _insert_data_variables(self):
756764
replacement_dict = {data: pymc_model[name] for name, data in self._name_to_data.items()}
757765
self.subbed_ssm = graph_replace(self.subbed_ssm, replace=replacement_dict, strict=True)
758766

759-
def _register_matrices_with_pymc_model(self) -> List[pt.TensorVariable]:
767+
def _register_matrices_with_pymc_model(self) -> list[pt.TensorVariable]:
760768
"""
761769
Add all statespace matrices to the PyMC model currently on the context stack as pm.Deterministic nodes, and
762770
adds named dimensions if they are found.
763771
764772
Returns
765773
-------
766774
registered_matrices: list of pt.TensorVariable
767-
List of statespace matrices, wrapped in pm.Deterministic
775+
list of statespace matrices, wrapped in pm.Deterministic
768776
"""
769777

770778
pm_mod = modelcontext(None)
@@ -788,9 +796,7 @@ def _register_matrices_with_pymc_model(self) -> List[pt.TensorVariable]:
788796
return registered_matrices
789797

790798
@staticmethod
791-
def _register_kalman_filter_outputs_with_pymc_model(
792-
outputs: tuple[pt.TensorVariable, ...]
793-
) -> None:
799+
def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVariable]) -> None:
794800
mod = modelcontext(None)
795801
states, covs = outputs[:4], outputs[4:]
796802

@@ -1014,7 +1020,7 @@ def _build_dummy_graph(self) -> None:
10141020
10151021
Returns
10161022
-------
1017-
List[pm.Flat]
1023+
list[pm.Flat]
10181024
A list of pm.Flat variables representing all parameters estimated by the model.
10191025
"""
10201026
for name in self.param_names:
@@ -1026,7 +1032,7 @@ def _build_dummy_graph(self) -> None:
10261032

10271033
def _kalman_filter_outputs_from_dummy_graph(
10281034
self,
1029-
) -> tuple[list[pt.TensorVariable, ...], tuple[pt.TensorVariable, pt.TensorVariable]]:
1035+
) -> tuple[list[pt.TensorVariable], list[tuple[pt.TensorVariable, pt.TensorVariable]]]:
10301036
"""
10311037
Builds a Kalman filter graph using "dummy" pm.Flat distributions for the model variables and sorts the returns
10321038
into (mean, covariance) pairs for each of filtered, predicted, and smoothed output.
@@ -1379,9 +1385,9 @@ def sample_unconditional_prior(
13791385

13801386
def sample_unconditional_posterior(
13811387
self,
1382-
idata,
1383-
steps=None,
1384-
use_data_time_dim=False,
1388+
idata: InferenceData,
1389+
steps: Optional[int] = None,
1390+
use_data_time_dim: bool = False,
13851391
random_seed: Optional[RandomState] = None,
13861392
**kwargs,
13871393
) -> InferenceData:
@@ -1775,7 +1781,8 @@ def _sort_obs_inputs_by_time_varying(self, d, Z):
17751781

17761782
return seqs, non_seqs
17771783

1778-
def _sort_obs_scan_args(self, args):
1784+
@staticmethod
1785+
def _sort_obs_scan_args(args):
17791786
args = list(args)
17801787

17811788
# If a matrix is time-varying, pytensor will put a [t] on the name

0 commit comments

Comments
 (0)