Closed
Description
I am experimenting with PyMC's LKJ priors for the variance-covariance matrix of multivariate normals. It would be nice if nutpie could sample this model, as the current PyMC sampler seems slow, at least for my simulations. This is likely a duplicate issue...
import pymc as pm
import numpy as np
import nutpie
import pytensor
import pytensor.tensor as pt
# Minimal reproducer: a multivariate normal whose Cholesky factor comes from
# an LKJ Cholesky covariance prior, compiled and sampled with nutpie.
mu = np.array([-3, -2, -1, 0, 1, 2, 3])
n = len(mu)

with pm.Model() as model:
    # Unregistered distribution handed to LKJCholeskyCov as `sd_dist`.
    sd_dist = pm.Exponential.dist(1.0, size=n)
    chol, corr, sigmas = pm.LKJCholeskyCov(
        "chol_cov",
        eta=4,
        n=n,
        sd_dist=sd_dist,
    )

    vals = pm.MvNormal("vals", mu=mu, chol=chol, size=13)

# The model is passed explicitly, so no model context is needed here.
# In the reported environment this compile step raises NotImplementedError.
compiled_model = nutpie.compile_pymc_model(model)
trace_pymc = nutpie.sample(compiled_model)
yields:
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[3], line 12
3 chol, corr, sigmas = pm.LKJCholeskyCov(
4 "chol_cov",
5 eta=4,
6 n=n,
7 sd_dist=sd_dist,
8 )
10 vals = pm.MvNormal("vals", mu=mu, chol=chol, size=13)
---> 12 compiled_model = nutpie.compile_pymc_model(model)
13 trace_pymc = nutpie.sample(compiled_model)
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/nutpie/compile_pymc.py:171, in compile_pymc_model(model, **kwargs)
147 def compile_pymc_model(model: pm.Model, **kwargs) -> CompiledPyMCModel:
148 """Compile necessary functions for sampling a pymc model.
149
150 Parameters
(...)
159
160 """
162 (
163 n_dim,
164 n_expanded,
165 logp_fn_pt,
166 logp_fn,
167 expand_fn_pt,
168 expand_fn,
169 shared_expand,
170 shape_info,
--> 171 ) = _make_functions(model)
173 shared_data = {val.name: val.get_value().copy() for val in logp_fn_pt.get_shared()}
174 for val in shared_data.values():
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/nutpie/compile_pymc.py:312, in _make_functions(model)
309 (logp, grad) = pytensor.graph_replace([logp, grad], replacements)
311 # We should avoid compiling the function, and optimize only
--> 312 logp_fn_pt = pytensor.compile.function.function(
313 (joined,), (logp, grad), mode=pytensor.compile.NUMBA
314 )
316 logp_fn = logp_fn_pt.vm.jit_fn
318 # Make function that computes remaining variables for the trace
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/compile/function/__init__.py:315, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
309 fn = orig_function(
310 inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
311 )
312 else:
313 # note: pfunc will also call orig_function -- orig_function is
314 # a choke point that all compilation must pass through
--> 315 fn = pfunc(
316 params=inputs,
317 outputs=outputs,
318 mode=mode,
319 updates=updates,
320 givens=givens,
321 no_default_updates=no_default_updates,
322 accept_inplace=accept_inplace,
323 name=name,
324 rebuild_strict=rebuild_strict,
325 allow_input_downcast=allow_input_downcast,
326 on_unused_input=on_unused_input,
327 profile=profile,
328 output_keys=output_keys,
329 )
330 return fn
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:468, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
454 profile = ProfileStats(message=profile)
456 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
457 params,
458 outputs,
(...)
465 fgraph=fgraph,
466 )
--> 468 return orig_function(
469 inputs,
470 cloned_outputs,
471 mode,
472 accept_inplace=accept_inplace,
473 name=name,
474 profile=profile,
475 on_unused_input=on_unused_input,
476 output_keys=output_keys,
477 fgraph=fgraph,
478 )
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/compile/function/types.py:1756, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
1744 m = Maker(
1745 inputs,
1746 outputs,
(...)
1753 fgraph=fgraph,
1754 )
1755 with config.change_flags(compute_test_value="off"):
-> 1756 fn = m.create(defaults)
1757 finally:
1758 t2 = time.perf_counter()
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/compile/function/types.py:1649, in FunctionMaker.create(self, input_storage, storage_map)
1646 start_import_time = pytensor.link.c.cmodule.import_time
1648 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1649 _fn, _i, _o = self.linker.make_thunk(
1650 input_storage=input_storage_lists, storage_map=storage_map
1651 )
1653 end_linker = time.perf_counter()
1655 linker_time = end_linker - start_linker
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/basic.py:254, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
247 def make_thunk(
248 self,
249 input_storage: Optional["InputStorageType"] = None,
(...)
252 **kwargs,
253 ) -> Tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 254 return self.make_all(
255 input_storage=input_storage,
256 output_storage=output_storage,
257 storage_map=storage_map,
258 )[:3]
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/basic.py:697, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
694 for k in storage_map:
695 compute_map[k] = [k.owner is None]
--> 697 thunks, nodes, jit_fn = self.create_jitable_thunk(
698 compute_map, nodes, input_storage, output_storage, storage_map
699 )
701 computed, last_user = gc_helper(nodes)
703 if self.allow_gc:
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
644 # This is a bit hackish, but we only return one of the output nodes
645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
648 self.fgraph,
649 order=order,
650 input_storage=input_storage,
651 output_storage=output_storage,
652 storage_map=storage_map,
653 )
655 thunk_inputs = self.create_thunk_inputs(storage_map)
657 thunks = []
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/numba/linker.py:27, in NumbaLinker.fgraph_convert(self, fgraph, **kwargs)
24 def fgraph_convert(self, fgraph, **kwargs):
25 from pytensor.link.numba.dispatch import numba_funcify
---> 27 return numba_funcify(fgraph, **kwargs)
File ~/miniforge3/envs/phd/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
905 if not args:
906 raise TypeError(f'{funcname} requires at least '
907 '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/numba/dispatch/basic.py:459, in numba_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
452 @numba_funcify.register(FunctionGraph)
453 def numba_funcify_FunctionGraph(
454 fgraph,
(...)
457 **kwargs,
458 ):
--> 459 return fgraph_to_python(
460 fgraph,
461 numba_funcify,
462 type_conversion_fn=numba_typify,
463 fgraph_name=fgraph_name,
464 **kwargs,
465 )
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/utils.py:738, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
736 body_assigns = []
737 for node in order:
--> 738 compiled_func = op_conversion_fn(
739 node.op, node=node, storage_map=storage_map, **kwargs
740 )
742 # Create a local alias with a unique name
743 local_compiled_func_name = unique_name(compiled_func)
File ~/miniforge3/envs/phd/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
905 if not args:
906 raise TypeError(f'{funcname} requires at least '
907 '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/numba/dispatch/elemwise.py:755, in numba_funcify_CAReduce(op, node, **kwargs)
753 input_name = get_name_for_object(node.inputs[0])
754 ndim = node.inputs[0].ndim
--> 755 careduce_py_fn = create_multiaxis_reducer(
756 op.scalar_op,
757 scalar_op_identity,
758 axes,
759 ndim,
760 np.dtype(node.outputs[0].type.dtype),
761 input_name=input_name,
762 )
764 careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False)
765 return careduce_fn
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/numba/dispatch/elemwise.py:358, in create_multiaxis_reducer(scalar_op, identity, axes, ndim, dtype, input_name, return_scalar)
320 r"""Construct a function that reduces multiple axes.
321
322 The functions generated by this function take the following form:
(...)
355
356 """
357 if len(axes) == 1:
--> 358 return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
360 axes = normalize_axis_tuple(axes, ndim)
362 careduce_fn_name = f"careduce_{scalar_op}"
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/numba/dispatch/elemwise.py:285, in create_axis_reducer(scalar_op, identity, axis, ndim, dtype, keepdims, return_scalar)
269 reduce_elemwise_def_src = f"""
270 def {reduce_elemwise_fn_name}(x):
271
(...)
282 return {return_expr}
283 """
284 else:
--> 285 inplace_update_statement = scalar_in_place_fn(scalar_op, "0", "res", "x[i]")
286 inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2)
288 return_expr = "res" if keepdims else "res.item()"
File ~/miniforge3/envs/phd/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
905 if not args:
906 raise TypeError(f'{funcname} requires at least '
907 '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)
File ~/miniforge3/envs/phd/lib/python3.11/site-packages/pytensor/link/numba/dispatch/elemwise.py:66, in scalar_in_place_fn(op, idx, res, arr)
51 @singledispatch
52 def scalar_in_place_fn(op: Op, idx: str, res: str, arr: str):
53 """Return code for an in-place update on an array using a binary scalar :class:`Op`.
54
55 Parameters
(...)
64 The symbol name for the second input.
65 """
---> 66 raise NotImplementedError()
NotImplementedError:
Which other PyTensor issues or areas would be most related to this?
Package versions
- PyMC v5.8.0
- PyTensor v2.15.0
- nutpie v0.8.1 but with #61 manually incorporated
Metadata
Assignees
Labels
No labels