-
-
Notifications
You must be signed in to change notification settings - Fork 60
Remove mode
argument from Statespace models
#482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR removes the now-unnecessary mode argument from the scan calls and related bookkeeping in the statespace models. Key changes include:
- Updates in tests and models to remove mode argument passed to build_graph and related functions.
- Elimination of mode-based branching in several files, including Kalman filter and smoother implementations.
Reviewed Changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated no comments.
Show a summary per file
File | Description |
---|---|
tests/statespace/utilities/test_helpers.py | Removed mode parameter from initialize_filter and kfilter.build_graph calls. |
tests/statespace/test_statespace_JAX.py | Removed mode parameter from build_statespace_graph call and updated corresponding tests. |
tests/statespace/test_kalman_filter.py | Removed mode parameter from initialize_filter call. |
tests/statespace/test_coord_assignment.py | Removed mode parameter from build_statespace_graph call. |
pymc_extras/statespace/models/structural.py | Removed mode parameter from build_statespace_graph calls. |
pymc_extras/statespace/models/SARIMAX.py | Removed mode parameter and simplified the conditional logic in _stationary_initialization. |
pymc_extras/statespace/filters/* | Removed mode argument and related get_mode calls from kalman filter, smoother, and distributions. |
pymc_extras/statespace/core/statespace.py | Removed mode handling from graph building and sampling functions. |
Comments suppressed due to low confidence (1)
pymc_extras/statespace/models/SARIMAX.py:369
- The conditional logic for selecting the method for solve_discrete_lyapunov has been removed and now always uses 'bilinear'. Please verify that always using 'bilinear' is appropriate, especially for models with fewer states where the 'direct' method was previously used.
def _stationary_initialization(self):
Do we want a transition period with future warning to aid users? I feel people may actually be using the module these days |
Yes that's more professional. I was waffling on this point. I guess I can raise a warning if you pass mode to Alternatively the warning could say "this should now be set when you create the model", e.g. I think that's convenient, but if a user wants that he can also change the pytensor config flag. So I'm somewhat torn. |
Both together sound fine |
Closes #478
We were previously passing the
mode
argument to scan everywhere in statespace. This isn't strictly necessary, and required a fair amount of bookkeeping. This PR removes the mode argument from all the scans in the repo, and also removes all that bookkeeping.This will slightly break the user-facing API, because now
build_statespace_graph(data, mode='JAX')
will now raise an error (because you don't need to pass anything).It will also now require users to explicitly pass compile_kwargs to all of the sampling functions (
sample_conditional_posterior
,forecast
,impulse_response_function
, etc). That's consistent with all PyMC apis, though.I still need to go through and adjust all the notebooks, but it looks like tests are passing.