From e2b5dd016f4bc53fb88c612196ae0325c0e75eb7 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 22 May 2025 14:44:11 +0800 Subject: [PATCH] Remove `mode` argument from everywhere --- pymc_extras/statespace/core/statespace.py | 33 ++----------------- .../statespace/filters/distributions.py | 8 ----- .../statespace/filters/kalman_filter.py | 19 +---------- .../statespace/filters/kalman_smoother.py | 8 ++--- pymc_extras/statespace/models/SARIMAX.py | 8 ++--- pymc_extras/statespace/models/VARMAX.py | 2 +- pymc_extras/statespace/models/structural.py | 8 ++--- tests/statespace/test_coord_assignment.py | 1 - tests/statespace/test_kalman_filter.py | 2 +- tests/statespace/test_statespace_JAX.py | 8 ++--- tests/statespace/utilities/test_helpers.py | 4 +-- 11 files changed, 19 insertions(+), 82 deletions(-) diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 715808d81..a1d951548 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -14,7 +14,6 @@ from pymc.model.transform.optimization import freeze_dims_and_data from pymc.util import RandomState from pytensor import Variable, graph_replace -from pytensor.compile import get_mode from rich.box import SIMPLE_HEAD from rich.console import Console from rich.table import Table @@ -222,7 +221,6 @@ def __init__( verbose: bool = True, measurement_error: bool = False, ): - self._fit_mode: str | None = None self._fit_coords: dict[str, Sequence[str]] | None = None self._fit_dims: dict[str, Sequence[str]] | None = None self._fit_data: pt.TensorVariable | None = None @@ -819,7 +817,6 @@ def build_statespace_graph( self, data: np.ndarray | pd.DataFrame | pt.TensorVariable, register_data: bool = True, - mode: str | None = None, missing_fill_value: float | None = None, cov_jitter: float | None = JITTER_DEFAULT, save_kalman_filter_outputs_in_idata: bool = False, @@ -889,7 +886,6 @@ def build_statespace_graph( filter_outputs = self.kalman_filter.build_graph( pt.as_tensor_variable(data), *self.unpack_statespace(), - mode=mode, missing_fill_value=missing_fill_value, cov_jitter=cov_jitter, ) @@ -900,7 +896,7 @@ def build_statespace_graph( filtered_covariances, predicted_covariances, observed_covariances = covs if save_kalman_filter_outputs_in_idata: smooth_states, smooth_covariances = self._build_smoother_graph( - filtered_states, filtered_covariances, self.unpack_statespace(), mode=mode + filtered_states, filtered_covariances, self.unpack_statespace() ) all_kf_outputs = [*states, smooth_states, *covs, smooth_covariances] self._register_kalman_filter_outputs_with_pymc_model(all_kf_outputs) @@ -919,7 +915,6 @@ def build_statespace_graph( self._fit_coords = pm_mod.coords.copy() self._fit_dims = pm_mod.named_vars_to_dims.copy() - self._fit_mode = mode def _build_smoother_graph( self, @@ -964,7 +959,7 @@ def _build_smoother_graph( *_, T, Z, R, H, Q = matrices smooth_states, smooth_covariances = self.kalman_smoother.build_graph( - T, R, Q, filtered_states, filtered_covariances, mode=mode, cov_jitter=cov_jitter + T, R, Q, filtered_states, filtered_covariances, cov_jitter=cov_jitter ) smooth_states.name = "smooth_states" smooth_covariances.name = "smooth_covariances" @@ -1082,7 +1077,6 @@ def _kalman_filter_outputs_from_dummy_graph( R, H, Q, - mode=self._fit_mode, ) filter_outputs.pop(-1) @@ -1092,7 +1086,7 @@ def _kalman_filter_outputs_from_dummy_graph( filtered_covariances, predicted_covariances, _ = covariances [smoothed_states, smoothed_covariances] = self.kalman_smoother.build_graph( - T, R, Q, filtered_states, filtered_covariances, mode=self._fit_mode + T, R, Q, filtered_states, filtered_covariances ) grouped_outputs = [ @@ -1208,7 +1202,6 @@ def _sample_conditional( for name in FILTER_OUTPUT_TYPES for suffix in ["", "_observed"] ], - compile_kwargs={"mode": get_mode(self._fit_mode)}, random_seed=random_seed, **kwargs, ) @@ -1308,7 +1301,6 @@ def _sample_unconditional( *matrices, steps=steps, dims=dims, - mode=self._fit_mode, sequence_names=self.kalman_filter.seq_names, k_endog=self.k_endog, ) @@ -1323,7 +1315,6 @@ def _sample_unconditional( idata_unconditional = pm.sample_posterior_predictive( group_idata, var_names=[f"{group}_latent", f"{group}_observed"], - compile_kwargs={"mode": self._fit_mode}, random_seed=random_seed, **kwargs, ) @@ -1547,7 +1538,6 @@ def sample_statespace_matrices( matrix_idata = pm.sample_posterior_predictive( idata if group == "posterior" else idata.prior, var_names=matrix_names, - compile_kwargs={"mode": self._fit_mode}, extend_inferencedata=False, ) @@ -2094,7 +2084,6 @@ def forecast( *matrices, steps=len(forecast_index), dims=dims, - mode=self._fit_mode, sequence_names=self.kalman_filter.seq_names, k_endog=self.k_endog, append_x0=False, @@ -2109,7 +2098,6 @@ def forecast( idata_forecast = pm.sample_posterior_predictive( idata, var_names=["forecast_latent", "forecast_observed"], - compile_kwargs={"mode": self._fit_mode}, random_seed=random_seed, **kwargs, ) @@ -2260,28 +2248,13 @@ def irf_step(shock, x, c, T, R): non_sequences=[c, T, R], n_steps=n_steps, strict=True, - mode=self._fit_mode, ) pm.Deterministic("irf", irf, dims=[TIME_DIM, ALL_STATE_DIM]) - compile_kwargs = kwargs.get("compile_kwargs", {}) - if "mode" not in compile_kwargs.keys(): - compile_kwargs = {"mode": self._fit_mode} - else: - mode = compile_kwargs.get("mode") - if mode is not None and mode != self._fit_mode: - raise ValueError( - f"User provided compile mode ({mode}) does not match the compile mode used to " - f"construct the model ({self._fit_mode})." - ) - - compile_kwargs.update({"mode": self._fit_mode}) - irf_idata = pm.sample_posterior_predictive( idata, var_names=["irf"], - compile_kwargs=compile_kwargs, random_seed=random_seed, **kwargs, ) diff --git a/pymc_extras/statespace/filters/distributions.py b/pymc_extras/statespace/filters/distributions.py index 1e4f2b153..81f3815b3 100644 --- a/pymc_extras/statespace/filters/distributions.py +++ b/pymc_extras/statespace/filters/distributions.py @@ -69,7 +69,6 @@ def __new__( H, Q, steps=None, - mode=None, sequence_names=None, append_x0=True, **kwargs, @@ -97,7 +96,6 @@ def __new__( H, Q, steps=steps, - mode=mode, sequence_names=sequence_names, append_x0=append_x0, **kwargs, @@ -116,7 +114,6 @@ def dist( H, Q, steps=None, - mode=None, sequence_names=None, append_x0=True, **kwargs, @@ -132,7 +129,6 @@ def dist( return super().dist( [a0, P0, c, d, T, Z, R, H, Q, steps], - mode=mode, sequence_names=sequence_names, append_x0=append_x0, **kwargs, @@ -152,7 +148,6 @@ def rv_op( Q, steps, size=None, - mode=None, sequence_names=None, append_x0=True, ): @@ -235,7 +230,6 @@ def step_fn(*args): sequences=None if len(sequences) == 0 else sequences, non_sequences=[*non_sequences, rng], n_steps=steps, - mode=mode, strict=True, ) @@ -279,7 +273,6 @@ def __new__( steps, k_endog=None, sequence_names=None, - mode=None, append_x0=True, **kwargs, ): @@ -307,7 +300,6 @@ def __new__( H, Q, steps=steps, - mode=mode, sequence_names=sequence_names, append_x0=append_x0, **kwargs, diff --git a/pymc_extras/statespace/filters/kalman_filter.py b/pymc_extras/statespace/filters/kalman_filter.py index 0ca47b50e..e4cbc8bed 100644 --- a/pymc_extras/statespace/filters/kalman_filter.py +++ b/pymc_extras/statespace/filters/kalman_filter.py @@ -5,7 +5,6 @@ import pytensor.tensor as pt from pymc.pytensorf import constant_fold -from pytensor.compile.mode import get_mode from pytensor.graph.basic import Variable from pytensor.raise_op import Assert from pytensor.tensor import TensorVariable @@ -28,15 +27,10 @@ class BaseFilter(ABC): - def __init__(self, mode=None): + def __init__(self): """ Kalman Filter. - Parameters - ---------- - mode : str, optional - The mode used for Pytensor compilation. Defaults to None. - Notes ----- The BaseFilter class is an abstract base class (ABC) for implementing kalman filters. @@ -44,9 +38,6 @@ def __init__(self, mode=None): Attributes ---------- - mode : str or None - The mode used for Pytensor compilation. - seq_names : list[str] A list of name representing time-varying statespace matrices. That is, inputs that will need to be provided to the `sequences` argument of `pytensor.scan` @@ -56,7 +47,6 @@ def __init__(self, mode=None): to the `non_sequences` argument of `pytensor.scan` """ - self.mode: str = mode self.seq_names: list[str] = [] self.non_seq_names: list[str] = [] @@ -153,7 +143,6 @@ def build_graph( R, H, Q, - mode=None, return_updates=False, missing_fill_value=None, cov_jitter=None, @@ -166,9 +155,6 @@ def build_graph( data : TensorVariable Data to be filtered - mode : optional, str - Pytensor compile mode, passed to pytensor.scan - return_updates: bool, default False Whether to return updates associated with the pytensor scan. Should only be requried to debug pruposes. @@ -199,7 +185,6 @@ def build_graph( if cov_jitter is None: cov_jitter = JITTER_DEFAULT - self.mode = mode self.missing_fill_value = missing_fill_value self.cov_jitter = cov_jitter @@ -227,7 +212,6 @@ def build_graph( outputs_info=[None, a0, None, None, P0, None, None], non_sequences=non_sequences, name="forward_kalman_pass", - mode=get_mode(self.mode), strict=False, ) @@ -800,7 +784,6 @@ def kalman_step(self, y, a, P, c, d, T, Z, R, H, Q): self._univariate_inner_filter_step, sequences=[y_masked, Z_masked, d, pt.diag(H_masked), nan_mask], outputs_info=[a, P, None, None, None], - mode=get_mode(self.mode), name="univariate_inner_scan", ) diff --git a/pymc_extras/statespace/filters/kalman_smoother.py b/pymc_extras/statespace/filters/kalman_smoother.py index f15913b86..b22473391 100644 --- a/pymc_extras/statespace/filters/kalman_smoother.py +++ b/pymc_extras/statespace/filters/kalman_smoother.py @@ -1,7 +1,6 @@ import pytensor import pytensor.tensor as pt -from pytensor.compile import get_mode from pytensor.tensor.nlinalg import matrix_dot from pymc_extras.statespace.filters.utilities import ( @@ -18,8 +17,7 @@ class KalmanSmoother: """ - def __init__(self, mode: str | None = None): - self.mode = mode + def __init__(self): self.cov_jitter = JITTER_DEFAULT self.seq_names = [] self.non_seq_names = [] @@ -64,9 +62,8 @@ def unpack_args(self, args): return a, P, a_smooth, P_smooth, T, R, Q def build_graph( - self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT + self, T, R, Q, filtered_states, filtered_covariances, cov_jitter=JITTER_DEFAULT ): - self.mode = mode self.cov_jitter = cov_jitter n, k = filtered_states.type.shape @@ -88,7 +85,6 @@ def build_graph( non_sequences=non_sequences, go_backwards=True, name="kalman_smoother", - mode=get_mode(self.mode), ) smoothed_states, smoothed_covariances = smoother_result diff --git a/pymc_extras/statespace/models/SARIMAX.py b/pymc_extras/statespace/models/SARIMAX.py index f8a420375..6981b972b 100644 --- a/pymc_extras/statespace/models/SARIMAX.py +++ b/pymc_extras/statespace/models/SARIMAX.py @@ -158,7 +158,7 @@ class BayesianSARIMA(PyMCStateSpace): rho = pm.Beta("ar_params", alpha=5, beta=1, dims=ss_mod.param_dims["ar_params"]) theta = pm.Normal("ma_params", mu=0.0, sigma=0.5, dims=ss_mod.param_dims["ma_params"]) - ss_mod.build_statespace_graph(df, mode="JAX") + ss_mod.build_statespace_graph(df) idata = pm.sample(nuts_sampler='numpyro') References @@ -366,7 +366,7 @@ def coords(self) -> dict[str, Sequence]: return coords - def _stationary_initialization(self, mode=None): + def _stationary_initialization(self): # Solve for matrix quadratic for P0 T = self.ssm["transition"] R = self.ssm["selection"] @@ -374,9 +374,7 @@ def _stationary_initialization(self, mode=None): c = self.ssm["state_intercept"] x0 = pt.linalg.solve(pt.identity_like(T) - T, c, assume_a="gen", check_finite=True) - - method = "direct" if (self.k_states < 5) or (mode == "JAX") else "bilinear" - P0 = solve_discrete_lyapunov(T, pt.linalg.matrix_dot(R, Q, R.T), method=method) + P0 = solve_discrete_lyapunov(T, pt.linalg.matrix_dot(R, Q, R.T), method="bilinear") return x0, P0 diff --git a/pymc_extras/statespace/models/VARMAX.py b/pymc_extras/statespace/models/VARMAX.py index b49765de0..d3ee2403d 100644 --- a/pymc_extras/statespace/models/VARMAX.py +++ b/pymc_extras/statespace/models/VARMAX.py @@ -135,7 +135,7 @@ class BayesianVARMAX(PyMCStateSpace): ar_params = pm.Normal("ar_params", mu=0, sigma=1, dims=ar_dims) state_cov = pm.Deterministic("state_cov", state_chol @ state_chol.T, dims=state_cov_dims) - bvar_mod.build_statespace_graph(data, mode="JAX") + bvar_mod.build_statespace_graph(data) idata = pm.sample(nuts_sampler="numpyro") """ diff --git a/pymc_extras/statespace/models/structural.py b/pymc_extras/statespace/models/structural.py index 40d1dedff..94f19d6c2 100644 --- a/pymc_extras/statespace/models/structural.py +++ b/pymc_extras/statespace/models/structural.py @@ -908,7 +908,7 @@ class MeasurementError(Component): intitial_trend = pm.Normal('initial_trend', sigma=10, dims=ss_mod.param_dims['initial_trend']) sigma_obs = pm.Exponential('sigma_obs', 1, dims=ss_mod.param_dims['sigma_obs']) - ss_mod.build_statespace_graph(data, mode='JAX') + ss_mod.build_statespace_graph(data) idata = pm.sample(nuts_sampler='numpyro') """ @@ -991,7 +991,7 @@ class AutoregressiveComponent(Component): ar_params = pm.Normal('ar_params', dims=ss_mod.param_dims['ar_params']) sigma_ar = pm.Exponential('sigma_ar', 1, dims=ss_mod.param_dims['sigma_ar']) - ss_mod.build_statespace_graph(data, mode='JAX') + ss_mod.build_statespace_graph(data) idata = pm.sample(nuts_sampler='numpyro') """ @@ -1153,7 +1153,7 @@ class TimeSeasonality(Component): intitial_trend = pm.Deterministic('initial_trend', pt.zeros(1), dims=ss_mod.param_dims['initial_trend']) annual_coefs = pm.Normal('annual_coefs', sigma=1e-2, dims=ss_mod.param_dims['annual_coefs']) trend_sigmas = pm.HalfNormal('trend_sigmas', sigma=1e-6, dims=ss_mod.param_dims['trend_sigmas']) - ss_mod.build_statespace_graph(data, mode='JAX') + ss_mod.build_statespace_graph(data) idata = pm.sample(nuts_sampler='numpyro') References @@ -1451,7 +1451,7 @@ class CycleComponent(Component): cycle_length = pm.Uniform('business_cycle_length', lower=6, upper=12) sigma_cycle = pm.HalfNormal('sigma_business_cycle', sigma=1) - ss_mod.build_statespace_graph(data, mode='JAX') + ss_mod.build_statespace_graph(data) idata = pm.sample(nuts_sampler='numpyro') diff --git a/tests/statespace/test_coord_assignment.py b/tests/statespace/test_coord_assignment.py index 40aaec124..4290c850e 100644 --- a/tests/statespace/test_coord_assignment.py +++ b/tests/statespace/test_coord_assignment.py @@ -137,7 +137,6 @@ def make_model(index): with pytest.warns(UserWarning, match="No time index found on the supplied data"): ss_mod.build_statespace_graph( a["A"], - mode="JAX", ) return model diff --git a/tests/statespace/test_kalman_filter.py b/tests/statespace/test_kalman_filter.py index 6c0bc18c6..8d5f63190 100644 --- a/tests/statespace/test_kalman_filter.py +++ b/tests/statespace/test_kalman_filter.py @@ -313,7 +313,7 @@ def test_kalman_filter_jax(filter): # TODO: Add UnivariateFilter to test; need to figure out the broadcasting issue when 2nd data dim is defined p, m, r, n = 1, 5, 1, 10 - inputs, outputs = initialize_filter(filter(), mode="JAX", p=p, m=m, r=r, n=n) + inputs, outputs = initialize_filter(filter(), p=p, m=m, r=r, n=n) inputs_np = make_test_inputs(p, m, r, n, rng) f_jax = get_jaxified_graph(inputs, outputs) diff --git a/tests/statespace/test_statespace_JAX.py b/tests/statespace/test_statespace_JAX.py index 9e8d9975f..74c0b5abe 100644 --- a/tests/statespace/test_statespace_JAX.py +++ b/tests/statespace/test_statespace_JAX.py @@ -37,9 +37,7 @@ def pymc_mod(ss_mod): rho = pm.Beta("rho", 1, 1) zeta = pm.Deterministic("zeta", 1 - rho) - ss_mod.build_statespace_graph( - data=nile, mode="JAX", save_kalman_filter_outputs_in_idata=True - ) + ss_mod.build_statespace_graph(data=nile, save_kalman_filter_outputs_in_idata=True) names = ["x0", "P0", "c", "d", "T", "Z", "R", "H", "Q"] for name, matrix in zip(names, ss_mod.unpack_statespace()): pm.Deterministic(name, matrix) @@ -62,7 +60,7 @@ def exog_pymc_mod(exog_ss_mod, rng): beta_exog = pm.Normal("beta_exog", dims=["exog_state"]) sigma_trend = pm.Exponential("sigma_trend", 1, dims=["trend_shock"]) - exog_ss_mod.build_statespace_graph(y, mode="JAX") + exog_ss_mod.build_statespace_graph(y) return m @@ -121,8 +119,6 @@ def test_no_nans_in_sampling_output(ss_mod, group, matrix, idata): @pytest.mark.parametrize("group", ["prior", "posterior"]) @pytest.mark.parametrize("kind", ["conditional", "unconditional"]) def test_sampling_methods(group, kind, ss_mod, idata, rng): - assert ss_mod._fit_mode == "JAX" - f = getattr(ss_mod, f"sample_{kind}_{group}") with pytest.warns(UserWarning, match="The RandomType SharedVariables"): test_idata = f(idata, random_seed=rng) diff --git a/tests/statespace/utilities/test_helpers.py b/tests/statespace/utilities/test_helpers.py index c6170f880..239affc46 100644 --- a/tests/statespace/utilities/test_helpers.py +++ b/tests/statespace/utilities/test_helpers.py @@ -34,7 +34,7 @@ def load_nile_test_data(): return nile -def initialize_filter(kfilter, mode=None, p=None, m=None, r=None, n=None): +def initialize_filter(kfilter, p=None, m=None, r=None, n=None): ksmoother = KalmanSmoother() data = pt.tensor(name="data", dtype=floatX, shape=(n, p)) a0 = pt.tensor(name="x0", dtype=floatX, shape=(m,)) @@ -57,7 +57,7 @@ def initialize_filter(kfilter, mode=None, p=None, m=None, r=None, n=None): predicted_covs, observed_covs, ll_obs, - ) = kfilter.build_graph(*inputs, mode=mode) + ) = kfilter.build_graph(*inputs) smoothed_states, smoothed_covs = ksmoother.build_graph(T, R, Q, filtered_states, filtered_covs)