Closed
Description
Describe the issue:
A model that sampled fine in 5.8.0 no longer works in 5.9.1 and throws a NotImplementedError (below).
The code might look a bit convoluted, but it is essentially a Gaussian Process over a time dimension that has very sparse values, which makes it a pretty useful construct in my models.
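To make the construct concrete, here is a minimal NumPy-only sketch of the idea: a Matern 5/2 covariance is built over the few unique time points, and one correlated offset vector per response is drawn from it and scaled. The helper `matern52`, the fixed lengthscale `2.5` (the mean of the `Gamma(5, 2)` prior used in the model) and the fixed scales `0.2` are illustrative assumptions, not part of the model itself.

```python
import numpy as np

def matern52(x, ls):
    """Matern 5/2 covariance matrix for 1-D inputs x with lengthscale ls."""
    r = np.abs(x[:, None] - x[None, :]) / ls
    s = np.sqrt(5.0) * r
    return (1.0 + s + s**2 / 3.0) * np.exp(-s)

# Unique (sparse) time points, as produced by factorize(sort=True) in the model
times = np.array([-1.8, -0.9])

# Assumption: lengthscale fixed at the prior mean alpha / beta = 5 / 2
K = matern52(times, ls=2.5)

rng = np.random.default_rng(0)
n_responses = 2

# One correlated offset vector over time per response: shape (response, time)
offsets = rng.multivariate_normal(np.zeros(len(times)), K, size=n_responses)

# Assumption: per-response scales fixed at 0.2 instead of a HalfNormal prior
sds = np.array([0.2, 0.2])
alpha_time = (sds[:, None] * offsets).T  # shape (time, response)
```

Because the covariance is only evaluated at the unique time values, the matrix stays small even when there are many observations.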
Reproducible code example:
import pandas as pd, numpy as np
import pymc as pm
print(pm.__version__)
from pymc.sampling import jax as pm_jax
df = pd.DataFrame([
    [-1.8, 'A'],
    [-1.8, 'B'],
    [-0.9, 'A'],
    [-1.8, 'B'],
    [-1.8, 'B'],
    [-0.9, 'B'],
    [-0.9, 'A']], columns=['t', 'response'])
times_idx, times = df["t"].factorize(sort=True)
resp_idx, responses = df['response'].factorize(sort=True)
COORDS = {
    'time': times,
    'response': responses,
    'obs_idx': np.array(df.index)
}
with pm.Model(coords=COORDS) as h_multinomial_model:
    obs = pm.MutableData("obs", resp_idx, dims=("obs_idx"))
    times_id = pm.MutableData("time_id", times_idx, dims="obs_idx")
    gp_inp = pm.MutableData('time_vals', np.array(times), dims="time")[:, None]
    ls = pm.Gamma(name='ls', alpha=5.0, beta=2.0)
    c1 = pm.gp.cov.Matern52(ls=ls, input_dim=1)
    gp_sds = pm.HalfNormal(f"σ_gp", 0.2, dims=('response',))
    α_time_offset = pm.MvNormal(f'α_time_offset', mu=0, cov=c1.full(gp_inp), dims=('response', "time"))
    α_time = pm.Deterministic(f'α_time', (gp_sds[:, None] * α_time_offset).transpose(), dims=("time", 'response'))
    # likelihood
    _ = pm.Categorical(
        "y",
        p=pm.math.softmax(α_time[times_id], axis=-1),
        observed=obs,
        dims=("obs_idx"),
    )
    idata = pm_jax.sample_numpyro_nuts()
Error message:
NotImplementedError: No JAX conversion for the given `Op`: Blockwise{SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=1}, (m,m),(m)->(m)}
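For context, a plausible reading of the error (an assumption, not checked against the 5.9.1 source) is that the lower-triangular solve comes from the `MvNormal` log-density, which is commonly evaluated through a Cholesky factor of the covariance. The NumPy-only sketch below shows where such a solve appears, and how a non-centered draw `x = L @ z` gives the same distribution without needing a solve in the forward direction. The function name `mvn_logp_chol` and the example numbers are illustrative; `np.linalg.solve` stands in for a dedicated triangular solver.

```python
import numpy as np

def mvn_logp_chol(x, mu, cov):
    """Multivariate normal log-density evaluated via a Cholesky factor."""
    L = np.linalg.cholesky(cov)  # lower-triangular, cov = L @ L.T
    # Triangular system L z = x - mu; this is the step a SolveTriangular op performs.
    # np.linalg.solve is a general solver used here as a stand-in.
    z = np.linalg.solve(L, x - mu)
    log_det = 2.0 * np.sum(np.log(np.diag(L)))
    k = x.shape[0]
    return -0.5 * (k * np.log(2.0 * np.pi) + log_det + z @ z)

# Illustrative values (assumptions, not taken from the model above)
cov = np.array([[1.0, 0.6], [0.6, 1.0]])
mu = np.zeros(2)
x = np.array([0.3, -0.1])
logp = mvn_logp_chol(x, mu, cov)

# Reference value from the textbook formula with an explicit inverse
diff = x - mu
ref = -0.5 * (2 * np.log(2.0 * np.pi)
              + np.linalg.slogdet(cov)[1]
              + diff @ np.linalg.inv(cov) @ diff)

# Non-centered form: start from standard-normal z and map forward with L,
# so only a matrix product is needed to obtain x
L = np.linalg.cholesky(cov)
z = np.array([0.5, -1.2])
x_nc = L @ z
```

Whether rewriting `α_time_offset` in this non-centered form avoids the unsupported `Op` in 5.9.1 is untested here; it is only one possible direction to try while the JAX conversion is missing.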
PyMC version information:
PyMC 5.9.1
Context for the issue:
This affects one of the main building blocks of the models I am working with. I can turn it off for the time being, but it adds a lot of power to the models because it lets us better model data that was gathered at different points in time.