
BUG: Regression in JAX model ops #6993

Closed
@velochy


Describe the issue:

A model that sampled fine with PyMC 5.8.0 no longer works with 5.9.1: sampling it with the JAX (NumPyro) backend now raises the NotImplementedError shown below.

The code may look a bit convoluted, but it is essentially a Gaussian process over a time dimension with very sparse values, which makes it a very useful construct in my models.
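The generative structure described above can be sketched in plain NumPy. The sketch builds a Matern 5/2 covariance over the unique time points, draws one correlated offset vector per response, scales each by a per-response standard deviation, and turns the result into per-observation category probabilities with a softmax. The fixed values for the lengthscale (2.5, the mean of the `Gamma(alpha=5, beta=2)` prior) and for the per-response standard deviations are assumptions standing in for the priors. It also uses a non-centered draw (`z @ L.T`) for illustration, whereas the model itself uses `pm.MvNormal` directly.

```python
import numpy as np


def matern52(x, ls):
    """Matern 5/2 covariance between all pairs of 1-D inputs x."""
    r = np.abs(x[:, None] - x[None, :]) / ls
    return (1.0 + np.sqrt(5.0) * r + 5.0 * r**2 / 3.0) * np.exp(-np.sqrt(5.0) * r)


rng = np.random.default_rng(0)
times = np.array([-1.8, -0.9])  # unique, sorted values of df["t"]
n_resp = 2                      # responses 'A' and 'B'

# Assumed fixed stand-ins for the ls and σ_gp priors
ls = 2.5
sigma = np.array([0.2, 0.1])

# Covariance over time, with a small jitter for a stable Cholesky factorization
K = matern52(times, ls) + 1e-9 * np.eye(len(times))
L = np.linalg.cholesky(K)

# Non-centered draw: each row of z @ L.T is distributed as MvNormal(0, K)
z = rng.standard_normal((n_resp, len(times)))
alpha_time_offset = z @ L.T                          # shape (response, time)
alpha_time = (sigma[:, None] * alpha_time_offset).T  # shape (time, response)

# Codes from df["t"].factorize(sort=True): one time index per observation
times_idx = np.array([0, 0, 1, 0, 0, 1, 1])
logits = alpha_time[times_idx]  # one row of logits per observation

# Numerically stable softmax over the response axis
p = np.exp(logits - logits.max(axis=-1, keepdims=True))
p /= p.sum(axis=-1, keepdims=True)
```

Because only two distinct time points occur, the GP only has to be evaluated at those points, and every observation simply indexes into the corresponding row of `alpha_time`.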

Reproducible code example:

import numpy as np
import pandas as pd
import pymc as pm
print(pm.__version__)
from pymc.sampling import jax as pm_jax


df = pd.DataFrame([
       [-1.8, 'A'],
       [-1.8, 'B'],
       [-0.9, 'A'],
       [-1.8, 'B'],
       [-1.8, 'B'],
       [-0.9, 'B'],
       [-0.9, 'A']], columns=['t','response'])


times_idx, times = df["t"].factorize(sort=True)
resp_idx, responses = df['response'].factorize(sort=True)

COORDS = { 
    'time': times,
    'response': responses,
    'obs_idx': np.array(df.index)
}

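with pm.Model(coords=COORDS) as h_multinomial_model:
    # Observed response codes and the time index of each observation
    obs = pm.MutableData("obs", resp_idx, dims="obs_idx")
    times_id = pm.MutableData("time_id", times_idx, dims="obs_idx")

    # GP inputs: the unique time values as a column vector of shape (n_time, 1)
    gp_inp = pm.MutableData("time_vals", np.array(times), dims="time")[:, None]

    # Matern 5/2 covariance over time with a Gamma prior on the lengthscale
    ls = pm.Gamma(name="ls", alpha=5.0, beta=2.0)
    c1 = pm.gp.cov.Matern52(ls=ls, input_dim=1)

    # One GP draw per response, scaled by a per-response standard deviation
    gp_sds = pm.HalfNormal("σ_gp", 0.2, dims=("response",))
    α_time_offset = pm.MvNormal(
        "α_time_offset", mu=0, cov=c1.full(gp_inp), dims=("response", "time")
    )
    α_time = pm.Deterministic(
        "α_time", (gp_sds[:, None] * α_time_offset).transpose(), dims=("time", "response")
    )

    # Likelihood: softmax over responses at each observation's time point
    _ = pm.Categorical(
        "y",
        p=pm.math.softmax(α_time[times_id], axis=-1),
        observed=obs,
        dims="obs_idx",
    )

    # Sampling with the JAX/NumPyro backend raises the error below
    idata = pm_jax.sample_numpyro_nuts()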

Error message:

``NotImplementedError: No JAX conversion for the given `Op`: Blockwise{SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=1}, (m,m),(m)->(m)}``
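The failing Op is `SolveTriangular` wrapped in `Blockwise`, PyTensor's wrapper for applying a core Op over leading batch dimensions according to a gufunc-style signature, here `(m,m),(m)->(m)`: one lower-triangular `m×m` matrix and one length-`m` vector per batch entry, returning one length-`m` solution. The NumPy sketch below (the function name `blockwise_solve_lower` is hypothetical) only illustrates what that batched computation amounts to; it is not PyTensor's or JAX's implementation. A plausible but unconfirmed source of the batch dimension in this model is the `response` dimension of the `MvNormal`, whose log-density needs a triangular solve against the Cholesky factor of the covariance for each row.

```python
import numpy as np


def blockwise_solve_lower(L, b):
    """Solve L @ x = b for lower-triangular L, broadcasting over leading batch dims.

    Hypothetical helper mirroring the core signature (m,m),(m)->(m).
    """
    L = np.asarray(L, dtype=float)
    b = np.asarray(b, dtype=float)
    m = b.shape[-1]
    # Broadcast the batch dimensions of L and b against each other
    batch = np.broadcast_shapes(L.shape[:-2], b.shape[:-1])
    L = np.broadcast_to(L, batch + (m, m))
    b = np.broadcast_to(b, batch + (m,))
    x = np.zeros(batch + (m,))
    for i in range(m):
        # Forward substitution: subtract the contributions of already-solved entries
        x[..., i] = (b[..., i] - np.sum(L[..., i, :i] * x[..., :i], axis=-1)) / L[..., i, i]
    return x


# Usage: one covariance shared by two right-hand sides (e.g. one per response)
K = np.array([[1.0, 0.5], [0.5, 1.0]])
L = np.linalg.cholesky(K)
b = np.array([[1.0, 2.0], [0.5, -1.0]])
x = blockwise_solve_lower(L, b)
```

A JAX conversion has to perform the same batched solve. In JAX this would typically be expressed by mapping `jax.scipy.linalg.solve_triangular` over the batch axis with `jax.vmap`, although how PyTensor's JAX backend handles it is not stated in this report.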

PyMC version information:

PyMC 5.9.1

Context for the issue:

This affects one of the main building blocks of the models I work with. I can turn it off for the time being, but it adds a lot of power to the models because it lets us better model data gathered at different points in time.
