Skip to content

Remove mode argument from Statespace models #482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jessegrabowski
Copy link
Member

Closes #478

We were previously passing the mode argument to scan everywhere in statespace. This isn't strictly necessary, and required a fair amount of bookkeeping. This PR removes the mode argument from all the scans in the repo, and also removes all that bookkeeping.

This will slightly break the user-facing API, because now build_statespace_graph(data, mode='JAX') will now raise an error (because you don't need to pass anything).

It will also now require users to explicitly pass compile_kwargs to all of the sampling functions (sample_conditional_posterior, forecast, impulse_response_function, etc). That's consistent with all PyMC apis, though.

I still need to go through and adjust all the notebooks, but it looks like tests are passing.

Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR removes the now-unnecessary mode argument from the scan calls and related bookkeeping in the statespace models. Key changes include:

  • Updates in tests and models to remove mode argument passed to build_graph and related functions.
  • Elimination of mode-based branching in several files, including Kalman filter and smoother implementations.

Reviewed Changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated no comments.

Show a summary per file
File Description
tests/statespace/utilities/test_helpers.py Removed mode parameter from initialize_filter and kfilter.build_graph calls.
tests/statespace/test_statespace_JAX.py Removed mode parameter from build_statespace_graph call and updated corresponding tests.
tests/statespace/test_kalman_filter.py Removed mode parameter from initialize_filter call.
tests/statespace/test_coord_assignment.py Removed mode parameter from build_statespace_graph call.
pymc_extras/statespace/models/structural.py Removed mode parameter from build_statespace_graph calls.
pymc_extras/statespace/models/SARIMAX.py Removed mode parameter and simplified the conditional logic in _stationary_initialization.
pymc_extras/statespace/filters/* Removed mode argument and related get_mode calls from kalman filter, smoother, and distributions.
pymc_extras/statespace/core/statespace.py Removed mode handling from graph building and sampling functions.
Comments suppressed due to low confidence (1)

pymc_extras/statespace/models/SARIMAX.py:369

  • The conditional logic for selecting the method for solve_discrete_lyapunov has been removed and now always uses 'bilinear'. Please verify that always using 'bilinear' is appropriate, especially for models with fewer states where the 'direct' method was previously used.
def _stationary_initialization(self):

@ricardoV94
Copy link
Member

ricardoV94 commented May 22, 2025

Do we want a transition period with future warning to aid users?

I feel people may actually be using the module these days

@jessegrabowski
Copy link
Member Author

Do we want a transition period with future warning to aid users?

I feel people may actually be using the module these days

Yes that's more professional. I was waffling on this point.

I guess I can raise a warning if you pass mode to build_statespace_graph and store it in the model, but not pass it to the scans. I could keep automatically setting it in the helper functions, but again with a warning that in the future we're not going to do this for you?

Alternatively the warning could say "this should now be set when you create the model", e.g. ss_mod = BayesianSARIMA(p=p, q=q, mode='JAX') would request that all sampling methods use compile_kwargs={'mode':'JAX'}, but it wouldn't do anything else.

I think that's convenient, but if a user wants that he can also change the pytensor config flag. So I'm somewhat torn.

@ricardoV94
Copy link
Member

I guess I can raise a warning if you pass mode to build_statespace_graph and store it in the model, but not pass it to the scans. I could keep automatically setting it in the helper functions, but again with a warning that in the future we're not going to do this for you?

Alternatively the warning could say ...

Both together sound fine

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Remove mode from statespace module functions
2 participants