
Numba dispatch of PyMC LKJ priors #432

Closed
@larryshamalama

Description


I am experimenting with PyMC LKJ priors for the variance-covariance matrix of a multivariate normal. It would be nice if nutpie could sample such a model, since the default PyMC sampler seems slow, at least for my simulations. This is likely a duplicate issue...

import pymc as pm
import numpy as np

import nutpie

mu = np.array([-3, -2, -1, 0, 1, 2, 3])
n = len(mu)

with pm.Model() as model:
    # Exponential prior on the standard deviations of the covariance factor
    sd_dist = pm.Exponential.dist(1.0, size=n)
    # LKJ prior on the Cholesky factor of the covariance matrix
    chol, corr, sigmas = pm.LKJCholeskyCov(
        "chol_cov",
        eta=4,
        n=n,
        sd_dist=sd_dist,
    )

    vals = pm.MvNormal("vals", mu=mu, chol=chol, size=13)

compiled_model = nutpie.compile_pymc_model(model)
trace_pymc = nutpie.sample(compiled_model)

yields:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[3], line 12
      3     chol, corr, sigmas = pm.LKJCholeskyCov(
      4         "chol_cov",
      5         eta=4,
      6         n=n,
      7         sd_dist=sd_dist,
      8     )
     10     vals = pm.MvNormal("vals", mu=mu, chol=chol, size=13)
---> 12 compiled_model = nutpie.compile_pymc_model(model)
     13 trace_pymc = nutpie.sample(compiled_model)

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/nutpie/compile_pymc.py:171, in compile_pymc_model(model, **kwargs)
    147 def compile_pymc_model(model: pm.Model, **kwargs) -> CompiledPyMCModel:
    148     """Compile necessary functions for sampling a pymc model.
    149 
    150     Parameters
   (...)
    159 
    160     """
    162     (
    163         n_dim,
    164         n_expanded,
    165         logp_fn_pt,
    166         logp_fn,
    167         expand_fn_pt,
    168         expand_fn,
    169         shared_expand,
    170         shape_info,
--> 171     ) = _make_functions(model)
    173     shared_data = {val.name: val.get_value().copy() for val in logp_fn_pt.get_shared()}
    174     for val in shared_data.values():

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/nutpie/compile_pymc.py:312, in _make_functions(model)
    309 (logp, grad) = pytensor.graph_replace([logp, grad], replacements)
    311 # We should avoid compiling the function, and optimize only
--> 312 logp_fn_pt = pytensor.compile.function.function(
    313     (joined,), (logp, grad), mode=pytensor.compile.NUMBA
    314 )
    316 logp_fn = logp_fn_pt.vm.jit_fn
    318 # Make function that computes remaining variables for the trace

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/compile/function/__init__.py:315, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
    309     fn = orig_function(
    310         inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
    311     )
    312 else:
    313     # note: pfunc will also call orig_function -- orig_function is
    314     #      a choke point that all compilation must pass through
--> 315     fn = pfunc(
    316         params=inputs,
    317         outputs=outputs,
    318         mode=mode,
    319         updates=updates,
    320         givens=givens,
    321         no_default_updates=no_default_updates,
    322         accept_inplace=accept_inplace,
    323         name=name,
    324         rebuild_strict=rebuild_strict,
    325         allow_input_downcast=allow_input_downcast,
    326         on_unused_input=on_unused_input,
    327         profile=profile,
    328         output_keys=output_keys,
    329     )
    330 return fn

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:468, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
    454     profile = ProfileStats(message=profile)
    456 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
    457     params,
    458     outputs,
   (...)
    465     fgraph=fgraph,
    466 )
--> 468 return orig_function(
    469     inputs,
    470     cloned_outputs,
    471     mode,
    472     accept_inplace=accept_inplace,
    473     name=name,
    474     profile=profile,
    475     on_unused_input=on_unused_input,
    476     output_keys=output_keys,
    477     fgraph=fgraph,
    478 )

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/compile/function/types.py:1756, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
   1744     m = Maker(
   1745         inputs,
   1746         outputs,
   (...)
   1753         fgraph=fgraph,
   1754     )
   1755     with config.change_flags(compute_test_value="off"):
-> 1756         fn = m.create(defaults)
   1757 finally:
   1758     t2 = time.perf_counter()

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/compile/function/types.py:1649, in FunctionMaker.create(self, input_storage, storage_map)
   1646 start_import_time = pytensor.link.c.cmodule.import_time
   1648 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1649     _fn, _i, _o = self.linker.make_thunk(
   1650         input_storage=input_storage_lists, storage_map=storage_map
   1651     )
   1653 end_linker = time.perf_counter()
   1655 linker_time = end_linker - start_linker

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/basic.py:254, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
    247 def make_thunk(
    248     self,
    249     input_storage: Optional["InputStorageType"] = None,
   (...)
    252     **kwargs,
    253 ) -> Tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 254     return self.make_all(
    255         input_storage=input_storage,
    256         output_storage=output_storage,
    257         storage_map=storage_map,
    258     )[:3]

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/basic.py:697, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
    694 for k in storage_map:
    695     compute_map[k] = [k.owner is None]
--> 697 thunks, nodes, jit_fn = self.create_jitable_thunk(
    698     compute_map, nodes, input_storage, output_storage, storage_map
    699 )
    701 computed, last_user = gc_helper(nodes)
    703 if self.allow_gc:

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
    644 # This is a bit hackish, but we only return one of the output nodes
    645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
    648     self.fgraph,
    649     order=order,
    650     input_storage=input_storage,
    651     output_storage=output_storage,
    652     storage_map=storage_map,
    653 )
    655 thunk_inputs = self.create_thunk_inputs(storage_map)
    657 thunks = []

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/numba/linker.py:27, in NumbaLinker.fgraph_convert(self, fgraph, **kwargs)
     24 def fgraph_convert(self, fgraph, **kwargs):
     25     from pytensor.link.numba.dispatch import numba_funcify
---> 27     return numba_funcify(fgraph, **kwargs)

File ~/miniforge3/envs/phd/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/numba/dispatch/basic.py:459, in numba_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    452 @numba_funcify.register(FunctionGraph)
    453 def numba_funcify_FunctionGraph(
    454     fgraph,
   (...)
    457     **kwargs,
    458 ):
--> 459     return fgraph_to_python(
    460         fgraph,
    461         numba_funcify,
    462         type_conversion_fn=numba_typify,
    463         fgraph_name=fgraph_name,
    464         **kwargs,
    465     )

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/utils.py:738, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    736 body_assigns = []
    737 for node in order:
--> 738     compiled_func = op_conversion_fn(
    739         node.op, node=node, storage_map=storage_map, **kwargs
    740     )
    742     # Create a local alias with a unique name
    743     local_compiled_func_name = unique_name(compiled_func)

File ~/miniforge3/envs/phd/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/numba/dispatch/elemwise.py:755, in numba_funcify_CAReduce(op, node, **kwargs)
    753 input_name = get_name_for_object(node.inputs[0])
    754 ndim = node.inputs[0].ndim
--> 755 careduce_py_fn = create_multiaxis_reducer(
    756     op.scalar_op,
    757     scalar_op_identity,
    758     axes,
    759     ndim,
    760     np.dtype(node.outputs[0].type.dtype),
    761     input_name=input_name,
    762 )
    764 careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False)
    765 return careduce_fn

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/numba/dispatch/elemwise.py:358, in create_multiaxis_reducer(scalar_op, identity, axes, ndim, dtype, input_name, return_scalar)
    320 r"""Construct a function that reduces multiple axes.
    321 
    322 The functions generated by this function take the following form:
   (...)
    355 
    356 """
    357 if len(axes) == 1:
--> 358     return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
    360 axes = normalize_axis_tuple(axes, ndim)
    362 careduce_fn_name = f"careduce_{scalar_op}"

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/numba/dispatch/elemwise.py:285, in create_axis_reducer(scalar_op, identity, axis, ndim, dtype, keepdims, return_scalar)
    269         reduce_elemwise_def_src = f"""
    270 def {reduce_elemwise_fn_name}(x):
    271 
   (...)
    282     return {return_expr}
    283         """
    284     else:
--> 285         inplace_update_statement = scalar_in_place_fn(scalar_op, "0", "res", "x[i]")
    286         inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2)
    288         return_expr = "res" if keepdims else "res.item()"

File ~/miniforge3/envs/phd/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
    905 if not args:
    906     raise TypeError(f'{funcname} requires at least '
    907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/numba/dispatch/elemwise.py:66, in scalar_in_place_fn(op, idx, res, arr)
     51 @singledispatch
     52 def scalar_in_place_fn(op: Op, idx: str, res: str, arr: str):
     53     """Return code for an in-place update on an array using a binary scalar :class:`Op`.
     54 
     55     Parameters
   (...)
     64         The symbol name for the second input.
     65     """
---> 66     raise NotImplementedError()

NotImplementedError: 
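
To help narrow this down, here is a quick diagnostic sketch. The traceback ends in numba_funcify_CAReduce -> scalar_in_place_fn, so some CAReduce node in the compiled graph uses a scalar op with no in-place Numba implementation registered; since the NotImplementedError carries no message, which scalar op fails is only a guess. This lists the reduction scalar ops in the model's logp graph:

from pytensor.graph.basic import ancestors
from pytensor.tensor.elemwise import CAReduce

# Inspect the logp graph; note nutpie also compiles the gradient graph,
# which may contain further reductions not covered by this check.
logp = model.logp()
scalar_ops = {
    type(var.owner.op.scalar_op).__name__
    for var in ancestors([logp])
    if var.owner is not None and isinstance(var.owner.op, CAReduce)
}
print(scalar_ops)  # candidates missing scalar_in_place_fn support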

Which other PyTensor issues or areas would be most closely related to this?

Package versions

  • PyMC v5.8.0
  • PyTensor v2.15.0
  • nutpie v0.8.1, with #61 manually incorporated
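
For what it's worth, the model does sample with PyMC's default sampler (the slow path mentioned above), so the failure seems specific to the Numba compilation path. A minimal check, assuming nothing beyond the model defined above:

# PyMC's default C-backend NUTS sampler avoids the Numba dispatch path
# that fails above; slower, but it confirms the model itself is fine.
with model:
    idata = pm.sample()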
