I'm quite interested in the results of this paper. The authors derive closed-form gradients for backpropagating through Kalman filters, specifically equations 28-31.
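Here is a toy illustration of what a closed-form Kalman filter gradient looks like. This is not a reproduction of the paper's equations 28-31, which are a backprop (reverse-mode) formulation. It is a minimal sketch that swaps in a different technique: forward sensitivity recursions for a scalar local-level model. The derivatives of the filtered mean and variance are carried alongside the filter, so the log-likelihood gradient with respect to the noise variances `q` and `r` comes out without any autodiff. The model, the function name `kalman_loglik_and_grad`, and the fixed prior `m0`, `p0` are my own assumptions for illustration.

```python
import math


def kalman_loglik_and_grad(ys, q, r, m0=0.0, p0=1.0):
    """Log-likelihood of a scalar local-level model and its gradient w.r.t. (q, r).

    Hypothetical model, assumed for illustration only:
        x_t = x_{t-1} + w_t,  w_t ~ N(0, q)
        y_t = x_t + e_t,      e_t ~ N(0, r)
    Derivatives of the filtered mean/variance are propagated forward with the
    filter (sensitivity equations), so no autodiff library is needed.
    """
    m, p = m0, p0
    dm = [0.0, 0.0]  # d m / d(q, r); the prior does not depend on q or r
    dp = [0.0, 0.0]  # d p / d(q, r)
    loglik = 0.0
    grad = [0.0, 0.0]
    for y in ys:
        # Predict step: random-walk state, so the mean carries over
        m_pred = m
        p_pred = p + q
        dm_pred = dm[:]
        dp_pred = [dp[0] + 1.0, dp[1]]  # d(p + q)/dq adds 1
        # Innovation v and its variance s
        v = y - m_pred
        s = p_pred + r
        dv = [-d for d in dm_pred]
        ds = [dp_pred[0], dp_pred[1] + 1.0]  # d(p_pred + r)/dr adds 1
        # Gaussian log-density of this observation and its derivative
        loglik += -0.5 * (math.log(2.0 * math.pi) + math.log(s) + v * v / s)
        for i in range(2):
            grad[i] += -0.5 * (ds[i] / s + 2.0 * v * dv[i] / s - v * v * ds[i] / (s * s))
        # Update step with Kalman gain k, differentiated via the quotient rule
        k = p_pred / s
        dk = [(dp_pred[i] * s - p_pred * ds[i]) / (s * s) for i in range(2)]
        m = m_pred + k * v
        p = p_pred - k * p_pred
        dm = [dm_pred[i] + dk[i] * v + k * dv[i] for i in range(2)]
        dp = [dp_pred[i] - dk[i] * p_pred - k * dp_pred[i] for i in range(2)]
    return loglik, grad


if __name__ == "__main__":
    ys = [0.3, -0.1, 0.8, 1.2, 0.9]
    ll, g = kalman_loglik_and_grad(ys, q=0.5, r=0.2)
    # Central finite differences as a sanity check on the closed-form gradient
    h = 1e-5
    fd_q = (kalman_loglik_and_grad(ys, 0.5 + h, 0.2)[0]
            - kalman_loglik_and_grad(ys, 0.5 - h, 0.2)[0]) / (2 * h)
    fd_r = (kalman_loglik_and_grad(ys, 0.5, 0.2 + h)[0]
            - kalman_loglik_and_grad(ys, 0.5, 0.2 - h)[0]) / (2 * h)
    print("closed form:", g)
    print("finite diff:", [fd_q, fd_r])
```

Forward sensitivities cost one extra recursion per parameter, which is cheap for two variances but scales poorly as the parameter count grows; that is where a reverse-mode formulation like the paper's should pay off.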
They report a 38x speedup over PyTorch autodiff gradients. I suspect (with no evidence) that gradient computation is where the default PyMC sampler really falls down, so this might even make non-JAX sampling of state-space (SS) models palatable.