Skip to content

External nuts sampler #560

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 31, 2023
Merged

External nuts sampler #560

merged 7 commits into from
Aug 31, 2023

Conversation

twiecki
Copy link
Member

@twiecki twiecki commented Jul 11, 2023

Update previous JAX sampling NB

The previous NB was very outdated, I changed the example to be PPCA and update to usenuts_sampler kwarg.

@aseyboldt nutpie is kinda slow on this example, not sure why.

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@aseyboldt
Copy link
Member

Looks like this model is pretty much a benchmark of the matrix multiply speed in the different backends.
For me all three samplers take about 15s with MKL. What blas do you have installed?

I also get warnings from the numba backend:

/home/adr/git/nuts-py/python/nutpie/compile_pymc.py:364: NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (Array(float64, 2, 'A', False, aligned=True), Array(float64, 2, 'C', False, aligned=True))
  return inner(x)
/home/adr/git/nuts-py/python/nutpie/compile_pymc.py:364: NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (Array(float64, 2, 'C', False, aligned=True), Array(float64, 2, 'A', False, aligned=True))

Maybe the ordered transform introduces some non-contigous arrays?

@aseyboldt
Copy link
Member

aseyboldt commented Jul 11, 2023

@twiecki Could you maybe check what you get for sampling time alone and compile time alone for numba?

Ie

import nutpie

compiled = nutpie.compile_pymc_model(PPCA)
%time nutpie.sample(compiled)

(for me compilation is ~11s and sampling ~4s)

We should also use a fixed seed in the notebook, otherwise the data will be different each time we execute it. For me neither of the samplers ends up with a converged posterior, which makes comparing the times pretty pointless, but that might just be because of the seed I used...

@twiecki
Copy link
Member Author

twiecki commented Jul 11, 2023

I added a seed.

compilation time is rather low for this model.

Wall time: 33.4 s for compilation + sampling vs 47s for just sampling.

@twiecki
Copy link
Member Author

twiecki commented Jul 11, 2023

Cause of Numba slowness (as "debugged" with @aseyboldt just now): OpenBLAS. Installing Accelerate via micromamba install "libblas=*=*accelerate" got nutpie down to JAX-level speeds (minus compilation time).

@twiecki twiecki requested a review from OriolAbril July 12, 2023 09:15
Copy link
Member

@OriolAbril OriolAbril left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I have commented on the myst file, but make the changes to the ipynb file

@twiecki
Copy link
Member Author

twiecki commented Aug 22, 2023

@OriolAbril I have implemented the requested changes.

@twiecki twiecki requested review from OriolAbril August 22, 2023 14:45
@twiecki twiecki merged commit d123659 into main Aug 31, 2023
@twiecki twiecki deleted the external_nuts_sampler branch August 31, 2023 13:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants