Skip to content

GLM-ordinal-regression should indicate extra dependencies #548

Open
@usptact

Description

@usptact

Notebook title: GLM-ordinal-regression
Notebook url: https://github.com/pymc-devs/pymc-examples/blob/main/examples/generalized_linear_models/GLM-ordinal-regression.ipynb

Issue description

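Unable to run cell 11 in the notebook. The cell calls `pm.sample(nuts_sampler="numpyro", ...)`, which imports `pymc.sampling.jax` and therefore `jax`. That import fails with the following JAX error: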

/home/vlad/py310/lib/python3.10/site-packages/pymc/sampling/mcmc.py:243: UserWarning: Use of external NUTS sampler is still experimental
  warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[11], line 36
     32     return idata, model
     35 priors = {"sigma": 1, "beta": [0, 1], "mu": np.linspace(0, K, K - 1)}
---> 36 idata1, model1 = make_model(priors, model_spec=1)
     37 idata2, model2 = make_model(priors, model_spec=2)
     38 idata3, model3 = make_model(priors, model_spec=3)

Cell In[11], line 30, in make_model(priors, model_spec, constrained_uniform, logit)
     28     else:
     29         y_ = pm.OrderedProbit("y", cutpoints=cutpoints, eta=mu, observed=df.explicit_rating)
---> 30     idata = pm.sample(nuts_sampler="numpyro", idata_kwargs={"log_likelihood": True})
     31     idata.extend(pm.sample_posterior_predictive(idata))
     32 return idata, model

File ~/py310/lib/python3.10/site-packages/pymc/sampling/mcmc.py:571, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    567     if not isinstance(step, NUTS):
    568         raise ValueError(
    569             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    570         )
--> 571     return _sample_external_nuts(
    572         sampler=nuts_sampler,
    573         draws=draws,
    574         tune=tune,
    575         chains=chains,
    576         target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    577         random_seed=random_seed,
    578         initvals=initvals,
    579         model=model,
    580         progressbar=progressbar,
    581         idata_kwargs=idata_kwargs,
    582         nuts_sampler_kwargs=nuts_sampler_kwargs,
    583         **kwargs,
    584     )
    586 if isinstance(step, list):
    587     step = CompoundStep(step)

File ~/py310/lib/python3.10/site-packages/pymc/sampling/mcmc.py:283, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
    280     return idata
    282 elif sampler == "numpyro":
--> 283     import pymc.sampling.jax as pymc_jax
    285     idata = pymc_jax.sample_numpyro_nuts(
    286         draws=draws,
    287         tune=tune,
   (...)
    295         **nuts_sampler_kwargs,
    296     )
    297     return idata

File ~/py310/lib/python3.10/site-packages/pymc/sampling/jax.py:23
     20 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
     22 import arviz as az
---> 23 import jax
     24 import numpy as np
     25 import pytensor.tensor as pt

File ~/py310/lib/python3.10/site-packages/jax/__init__.py:160
    158 from jax import abstract_arrays as abstract_arrays
    159 from jax import custom_derivatives as custom_derivatives
--> 160 from jax import custom_batching as custom_batching
    161 from jax import custom_transpose as custom_transpose
    162 from jax import api_util as api_util

File ~/py310/lib/python3.10/site-packages/jax/custom_batching.py:15
      1 # Copyright 2021 The JAX Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from jax._src.custom_batching import (
     16   custom_vmap,
     17   sequential_vmap,
     18 )

File ~/py310/lib/python3.10/site-packages/jax/_src/custom_batching.py:19
     16 import operator
     17 from typing import Callable, Optional
---> 19 from jax import lax
     20 from jax._src import api
     21 from jax._src import core

File ~/py310/lib/python3.10/site-packages/jax/lax/__init__.py:369
    363 from jax._src.lax.ann import (
    364   approx_max_k as approx_max_k,
    365   approx_min_k as approx_min_k,
    366   approx_top_k_p as approx_top_k_p
    367 )
    368 from jax._src.ad_util import stop_gradient_p as stop_gradient_p
--> 369 from jax.lax import linalg as linalg
    371 from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
    372 from jax._src.pjit import sharding_constraint_p as sharding_constraint_p

File ~/py310/lib/python3.10/site-packages/jax/lax/linalg.py:15
      1 # Copyright 2020 The JAX Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from jax._src.lax.linalg import (
     16   cholesky,
     17   cholesky_p,
     18   eig,
     19   eig_p,
     20   eigh,
     21   eigh_p,
     22   hessenberg,
     23   hessenberg_p,
     24   lu,
     25   lu_p,
     26   lu_pivots_to_permutation,
     27   householder_product,
     28   householder_product_p,
     29   qr,
     30   qr_p,
     31   svd,
     32   svd_p,
     33   triangular_solve,
     34   triangular_solve_p,
     35   tridiagonal,
     36   tridiagonal_p,
     37   tridiagonal_solve,
     38   tridiagonal_solve_p,
     39   schur,
     40   schur_p
     41 )
     44 from jax._src.lax.qdwh import (
     45   qdwh as qdwh
     46 )

File ~/py310/lib/python3.10/site-packages/jax/_src/lax/linalg.py:37
     35 from jax._src.interpreters import mlir
     36 from jax._src.lax import control_flow
---> 37 from jax._src.lax import eigh as lax_eigh
     38 from jax._src.lax import lax as lax_internal
     39 from jax._src.lax import svd as lax_svd

File ~/py310/lib/python3.10/site-packages/jax/_src/lax/eigh.py:39
     37 from jax._src.numpy import ufuncs
     38 from jax import lax
---> 39 from jax._src.lax import qdwh
     40 from jax._src.lax import linalg as lax_linalg
     41 from jax._src.lax.stack import Stack

File ~/py310/lib/python3.10/site-packages/jax/_src/lax/qdwh.py:31
     28 from typing import Optional, Tuple
     30 import jax
---> 31 import jax.numpy as jnp
     32 from jax import lax
     33 from jax._src import core

File ~/py310/lib/python3.10/site-packages/jax/numpy/__init__.py:260
    257 # TODO(phawkins): make this import unconditional after increasing the ml_dtypes
    258 # minimum version.
    259 import jax._src.numpy.lax_numpy
--> 260 if hasattr(jax._src.numpy.lax_numpy, "int4"):
    261   from jax._src.numpy.lax_numpy import (
    262     int4 as int4,
    263     uint4 as uint4,
    264   )
    267 from jax._src.numpy.index_tricks import (
    268   c_ as c_,
    269   index_exp as index_exp,
   (...)
    273   s_ as s_,
    274 )

AttributeError: partially initialized module 'jax' has no attribute '_src' (most likely due to a circular import)

Note that this issue tracker is about the contents of the notebooks. If
the notebook is instead triggering a bug or error in PyMC, please
report it to https://github.com/pymc-devs/pymc/issues instead.

Expected output

If applicable, describe what should happen instead.

Proposed solution

If applicable, explain possible solutions and workarounds.

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions