Skip to content

Refactor jax internals to support dense_mass kwarg for numpyro #7050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

ferrine
Copy link
Member

@ferrine ferrine commented Dec 6, 2023

What is this PR about?
Enables Block Dense mass matrix adaptation for numpyro

Checklist

Major / Breaking Changes

  • ...

New features

  • Block mass matrix for numpyro
  • get_jaxified_logp now accepts point_fn argument
with pm.Model(
        coords=dict(level=["Basement", "Floor"], county=[1, 2]),
) as model:
    # multilevel modelling
    a = pm.Normal("a")
    s = pm.HalfNormal("s")
    a_g = pm.Normal("a_g", a, s, dims="level")
    s_g = pm.HalfNormal("s_g")
    a_ig = pm.Normal("a_ig", a_g, s_g, dims=("county", "level"))
    trace = sample_numpyro_nuts(
        nuts_kwargs=dict(
            dense_mass=[
                ("a", "a_g"),
            ]
        )
    )

Bugfixes

  • ...

Documentation

  • ...

Maintenance

  • ...

📚 Documentation preview 📚: https://pymc--7050.org.readthedocs.build/en/7050/

Copy link

codecov bot commented Dec 6, 2023

Codecov Report

Merging #7050 (8eb4284) into main (2e05854) will decrease coverage by 12.23%.
The diff coverage is 0.00%.

Additional details and impacted files

Impacted file tree graph

@@             Coverage Diff             @@
##             main    #7050       +/-   ##
===========================================
- Coverage   92.19%   79.97%   -12.23%     
===========================================
  Files         101      101               
  Lines       16893    16911       +18     
===========================================
- Hits        15575    13524     -2051     
- Misses       1318     3387     +2069     
Files Coverage Δ
pymc/sampling/jax.py 0.00% <0.00%> (-93.08%) ⬇️

... and 31 files with indirect coverage changes

@ferrine ferrine requested a review from ricardoV94 December 6, 2023 19:24
@ricardoV94 ricardoV94 marked this pull request as draft December 10, 2023 11:36
@ricardoV94 ricardoV94 marked this pull request as ready for review December 11, 2023 13:55
@ricardoV94
Copy link
Member

ricardoV94 commented Dec 11, 2023

These failing tests are definitely a latest PyTensor issue, I'll patch it

@ricardoV94 ricardoV94 changed the title refactor jax internals to support dense_mass kwarg for numpyro Refactor jax internals to support dense_mass kwarg for numpyro Dec 11, 2023
@ricardoV94
Copy link
Member

Failing tests due to PyTensor should be fixed by pymc-devs/pytensor#546

@ricardoV94
Copy link
Member

@ferrine can you rebase?

@ferrine ferrine marked this pull request as draft July 9, 2024 09:09
@ferrine
Copy link
Member Author

ferrine commented Jul 9, 2024

The rebase did not went as smooth there, converting thit to draft

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants