Open
Description
Notebook title: GLM-ordinal-regression
Notebook url: https://github.com/pymc-devs/pymc-examples/blob/main/examples/generalized_linear_models/GLM-ordinal-regression.ipynb
Issue description
Unable to run cell 11 in the notebook. Getting a JAX error:
/home/vlad/py310/lib/python3.10/site-packages/pymc/sampling/mcmc.py:243: UserWarning: Use of external NUTS sampler is still experimental
warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[11], line 36
32 return idata, model
35 priors = {"sigma": 1, "beta": [0, 1], "mu": np.linspace(0, K, K - 1)}
---> 36 idata1, model1 = make_model(priors, model_spec=1)
37 idata2, model2 = make_model(priors, model_spec=2)
38 idata3, model3 = make_model(priors, model_spec=3)
Cell In[11], line 30, in make_model(priors, model_spec, constrained_uniform, logit)
28 else:
29 y_ = pm.OrderedProbit("y", cutpoints=cutpoints, eta=mu, observed=df.explicit_rating)
---> 30 idata = pm.sample(nuts_sampler="numpyro", idata_kwargs={"log_likelihood": True})
31 idata.extend(pm.sample_posterior_predictive(idata))
32 return idata, model
File ~/py310/lib/python3.10/site-packages/pymc/sampling/mcmc.py:571, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
567 if not isinstance(step, NUTS):
568 raise ValueError(
569 "Model can not be sampled with NUTS alone. Your model is probably not continuous."
570 )
--> 571 return _sample_external_nuts(
572 sampler=nuts_sampler,
573 draws=draws,
574 tune=tune,
575 chains=chains,
576 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
577 random_seed=random_seed,
578 initvals=initvals,
579 model=model,
580 progressbar=progressbar,
581 idata_kwargs=idata_kwargs,
582 nuts_sampler_kwargs=nuts_sampler_kwargs,
583 **kwargs,
584 )
586 if isinstance(step, list):
587 step = CompoundStep(step)
File ~/py310/lib/python3.10/site-packages/pymc/sampling/mcmc.py:283, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
280 return idata
282 elif sampler == "numpyro":
--> 283 import pymc.sampling.jax as pymc_jax
285 idata = pymc_jax.sample_numpyro_nuts(
286 draws=draws,
287 tune=tune,
(...)
295 **nuts_sampler_kwargs,
296 )
297 return idata
File ~/py310/lib/python3.10/site-packages/pymc/sampling/jax.py:23
20 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
22 import arviz as az
---> 23 import jax
24 import numpy as np
25 import pytensor.tensor as pt
File ~/py310/lib/python3.10/site-packages/jax/__init__.py:160
158 from jax import abstract_arrays as abstract_arrays
159 from jax import custom_derivatives as custom_derivatives
--> 160 from jax import custom_batching as custom_batching
161 from jax import custom_transpose as custom_transpose
162 from jax import api_util as api_util
File ~/py310/lib/python3.10/site-packages/jax/custom_batching.py:15
1 # Copyright 2021 The JAX Authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
---> 15 from jax._src.custom_batching import (
16 custom_vmap,
17 sequential_vmap,
18 )
File ~/py310/lib/python3.10/site-packages/jax/_src/custom_batching.py:19
16 import operator
17 from typing import Callable, Optional
---> 19 from jax import lax
20 from jax._src import api
21 from jax._src import core
File ~/py310/lib/python3.10/site-packages/jax/lax/__init__.py:369
363 from jax._src.lax.ann import (
364 approx_max_k as approx_max_k,
365 approx_min_k as approx_min_k,
366 approx_top_k_p as approx_top_k_p
367 )
368 from jax._src.ad_util import stop_gradient_p as stop_gradient_p
--> 369 from jax.lax import linalg as linalg
371 from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
372 from jax._src.pjit import sharding_constraint_p as sharding_constraint_p
File ~/py310/lib/python3.10/site-packages/jax/lax/linalg.py:15
1 # Copyright 2020 The JAX Authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
---> 15 from jax._src.lax.linalg import (
16 cholesky,
17 cholesky_p,
18 eig,
19 eig_p,
20 eigh,
21 eigh_p,
22 hessenberg,
23 hessenberg_p,
24 lu,
25 lu_p,
26 lu_pivots_to_permutation,
27 householder_product,
28 householder_product_p,
29 qr,
30 qr_p,
31 svd,
32 svd_p,
33 triangular_solve,
34 triangular_solve_p,
35 tridiagonal,
36 tridiagonal_p,
37 tridiagonal_solve,
38 tridiagonal_solve_p,
39 schur,
40 schur_p
41 )
44 from jax._src.lax.qdwh import (
45 qdwh as qdwh
46 )
File ~/py310/lib/python3.10/site-packages/jax/_src/lax/linalg.py:37
35 from jax._src.interpreters import mlir
36 from jax._src.lax import control_flow
---> 37 from jax._src.lax import eigh as lax_eigh
38 from jax._src.lax import lax as lax_internal
39 from jax._src.lax import svd as lax_svd
File ~/py310/lib/python3.10/site-packages/jax/_src/lax/eigh.py:39
37 from jax._src.numpy import ufuncs
38 from jax import lax
---> 39 from jax._src.lax import qdwh
40 from jax._src.lax import linalg as lax_linalg
41 from jax._src.lax.stack import Stack
File ~/py310/lib/python3.10/site-packages/jax/_src/lax/qdwh.py:31
28 from typing import Optional, Tuple
30 import jax
---> 31 import jax.numpy as jnp
32 from jax import lax
33 from jax._src import core
File ~/py310/lib/python3.10/site-packages/jax/numpy/__init__.py:260
257 # TODO(phawkins): make this import unconditional after increasing the ml_dtypes
258 # minimum version.
259 import jax._src.numpy.lax_numpy
--> 260 if hasattr(jax._src.numpy.lax_numpy, "int4"):
261 from jax._src.numpy.lax_numpy import (
262 int4 as int4,
263 uint4 as uint4,
264 )
267 from jax._src.numpy.index_tricks import (
268 c_ as c_,
269 index_exp as index_exp,
(...)
273 s_ as s_,
274 )
AttributeError: partially initialized module 'jax' has no attribute '_src' (most likely due to a circular import)
Note that this issue tracker is about the contents of the notebooks. If
the notebook is instead triggering a bug or error in pymc, please
report it to https://github.com/pymc-devs/pymc/issues instead.
Expected output
If applicable, describe what should happen instead.
Proposed solution
If applicable, explain possible solutions and workarounds.