Closed
Description
Describe the issue:
While working out how to vectorize a model's logp, I got stuck because Dirichlet breaks under vectorization. This is most likely a PyTensor issue, since the error originates in the rewrites.
Reproducible code example:
import numpy as np
import pymc as pm
import pytensor

with pm.Model() as test_model:
    v = pm.Dirichlet("v", [1., 1., 1.], shape=3)

test_lp = test_model.logp()
# Replace each value variable with a copy that has a leading batch dimension.
test_values_vec = {
    v: v.type.clone(shape=(None, *v.type.shape))(name=f"{v.name}[]")
    for v in test_model.value_vars
}
test_vectorized_lp = pytensor.graph.vectorize_graph([test_lp], test_values_vec)
test_f = pm.compile_pymc(
    # works with mode="FAST_COMPILE"
    list(test_values_vec.values()), test_vectorized_lp
)
# 10 batched points in the 2-dimensional unconstrained (simplex-transformed) space
test_f(np.random.randn(10, 2))
Error message:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File ~/dev/pymc/.venv/lib/python3.11/site-packages/pytensor/compile/function/types.py:959, in Function.__call__(self, *args, **kwargs)
957 try:
958 outputs = (
--> 959 self.vm()
960 if output_subset is None
961 else self.vm(output_subset=output_subset)
962 )
963 except Exception:
ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 0, the array at index 0 has size 10 and the array at index 1 has size 1
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
Cell In[307], line 1
----> 1 test_f(np.random.randn(10, 2))
File ~/dev/pymc/.venv/lib/python3.11/site-packages/pytensor/compile/function/types.py:972, in Function.__call__(self, *args, **kwargs)
970 if hasattr(self.vm, "thunks"):
971 thunk = self.vm.thunks[self.vm.position_of_error]
--> 972 raise_with_op(
973 self.maker.fgraph,
974 node=self.vm.nodes[self.vm.position_of_error],
975 thunk=thunk,
976 storage_map=getattr(self.vm, "storage_map", None),
977 )
978 else:
979 # old-style linkers raise their own exceptions
980 raise
File ~/dev/pymc/.venv/lib/python3.11/site-packages/pytensor/link/utils.py:528, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
523 warnings.warn(
524 f"{exc_type} error does not allow us to add an extra error message"
525 )
526 # Some exception need extra parameter in inputs. So forget the
527 # extra long error message in that case.
--> 528 raise exc_value.with_traceback(exc_trace)
File ~/dev/pymc/.venv/lib/python3.11/site-packages/pytensor/compile/function/types.py:959, in Function.__call__(self, *args, **kwargs)
956 t0_fn = time.perf_counter()
957 try:
958 outputs = (
--> 959 self.vm()
960 if output_subset is None
961 else self.vm(output_subset=output_subset)
962 )
963 except Exception:
964 restore_defaults()
ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 0, the array at index 0 has size 10 and the array at index 1 has size 1
Apply node that caused the error: Join(1, Add.0, [[0.]])
Toposort index: 5
Inputs types: [TensorType(int8, shape=()), TensorType(float64, shape=(None, 2)), TensorType(float64, shape=(1, 1))]
Inputs shapes: [(), (10, 2), (1, 1)]
Inputs strides: [(), (16, 8), (8, 8)]
Inputs values: [array(1, dtype=int8), 'not shown', array([[0.]])]
Outputs clients: [[Max{axis=1}(Join.0), Composite{switch(i2, i3, exp((i0 - i1)))}(Join.0, ExpandDims{axis=1}.0, Isinf.0, Exp.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/home/ferres/dev/pymc/.venv/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
coro.send(None)
File "/home/ferres/dev/pymc/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/home/ferres/dev/pymc/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/home/ferres/dev/pymc/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/tmp/ipykernel_921353/1334936598.py", line 8, in <module>
test_vectorized_lp = pytensor.graph.vectorize_graph([test_lp], test_values_vec)
File "/home/ferres/dev/pymc/.venv/lib/python3.11/site-packages/pytensor/graph/replace.py", line 301, in vectorize_graph
vect_node = vectorize_node(node, *vect_inputs)
File "/home/ferres/dev/pymc/.venv/lib/python3.11/site-packages/pytensor/graph/replace.py", line 217, in vectorize_node
return _vectorize_node(op, node, *batched_inputs)
File "/home/ferres/dev/pymc/.venv/lib/python3.11/functools.py", line 909, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
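The failing node can be mimicked outside PyTensor. The following is a minimal NumPy sketch, not the PyTensor code path: it uses the shapes from the "Inputs shapes" line above, with zero arrays standing in for the real values, and `join_fails` is a hypothetical helper introduced only for illustration. It shows that a constant left with batch size 1 cannot be concatenated along axis 1 with a batched input, and that broadcasting it along the batch axis first avoids the mismatch.

```python
import numpy as np

# Shapes taken from the "Inputs shapes" line of the traceback.
batched = np.zeros((10, 2))   # stands in for Add.0
constant = np.zeros((1, 1))   # the [[0.]] constant that kept batch size 1


def join_fails(a, b, axis=1):
    """Return True if concatenating a and b along `axis` raises ValueError."""
    try:
        np.concatenate([a, b], axis=axis)
    except ValueError:
        return True
    return False


# Concatenation along axis 1 requires matching sizes along axis 0.
mismatch = join_fails(batched, constant)

# Broadcasting the constant along the batch axis makes the join valid.
constant_b = np.broadcast_to(constant, (batched.shape[0], 1))
joined = np.concatenate([batched, constant_b], axis=1)
```

This only illustrates the shape mismatch reported by the Join node; which rewrite produces the unbroadcast constant is not established here.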
Graph of the function
Composite{(switch(or(i3, i4), -inf, 0.6931471805599453) + ((1.0986122886681098 + (3.0 * i2)) - (3.0 * (i1 + log(i0)))))} [id A] 20
├─ Sum{axis=1} [id B] 17
│ └─ Composite{switch(i2, i3, exp((i0 - i1)))} [id C] 14
│ ├─ Join [id D] 5
│ │ ├─ 1 [id E]
│ │ ├─ Add [id F] 3
│ │ │ ├─ v_simplex__[] [id G]
│ │ │ └─ ExpandDims{axis=1} [id H] 1
│ │ │ └─ Sum{axis=1} [id I] 0
│ │ │ └─ v_simplex__[] [id G]
│ │ └─ [[0.]] [id J]
│ ├─ ExpandDims{axis=1} [id K] 9
│ │ └─ Max{axis=1} [id L] 7
│ │ └─ Join [id D] 5
│ │ └─ ···
│ ├─ Isinf [id M] 12
│ │ └─ ExpandDims{axis=1} [id K] 9
│ │ └─ ···
│ └─ Exp [id N] 11
│ └─ ExpandDims{axis=1} [id K] 9
│ └─ ···
├─ Max{axis=1} [id L] 7
│ └─ ···
├─ Sum{axis=1} [id I] 0
│ └─ ···
├─ Any{axis=1} [id O] 19
│ └─ Lt [id P] 16
│ ├─ Softmax{axis=1} [id Q] 13
│ │ └─ Sub [id R] 10
│ │ ├─ Join [id S] 4
│ │ │ ├─ 1 [id E]
│ │ │ ├─ v_simplex__[] [id G]
│ │ │ └─ Neg [id T] 2
│ │ │ └─ ExpandDims{axis=1} [id H] 1
│ │ │ └─ ···
│ │ └─ ExpandDims{axis=1} [id U] 8
│ │ └─ Max{axis=1} [id V] 6
│ │ └─ Join [id S] 4
│ │ └─ ···
│ └─ [[0]] [id W]
└─ Any{axis=1} [id X] 18
└─ Gt [id Y] 15
├─ Softmax{axis=1} [id Q] 13
│ └─ ···
└─ [[1]] [id Z]
Inner graphs:
Composite{(switch(or(i3, i4), -inf, 0.6931471805599453) + ((1.0986122886681098 + (3.0 * i2)) - (3.0 * (i1 + log(i0)))))} [id A]
← add [id BA] 'o0'
├─ Switch [id BB]
│ ├─ OR [id BC]
│ │ ├─ i3 [id BD]
│ │ └─ i4 [id BE]
│ ├─ -inf [id BF]
│ └─ 0.6931471805599453 [id BG]
└─ sub [id BH]
├─ add [id BI]
│ ├─ 1.0986122886681098 [id BJ]
│ └─ mul [id BK]
│ ├─ t1{3.0} [id BL]
│ └─ i2 [id BM]
└─ mul [id BN]
├─ t1{3.0} [id BL]
└─ add [id BO]
├─ i1 [id BP]
└─ log [id BQ]
└─ i0 [id BR]
Composite{switch(i2, i3, exp((i0 - i1)))} [id C]
← Switch [id BS] 'o0'
├─ i2 [id BT]
├─ i3 [id BU]
└─ exp [id BV]
└─ sub [id BW]
├─ i0 [id BX]
└─ i1 [id BY]
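For reference, the `Join [id S]` and `Softmax{axis=1} [id Q]` branch of the graph above appears to map the batched unconstrained values back onto the simplex. A rough NumPy sketch of that reading follows. It is an interpretation of the printed graph, not PyMC's implementation, and `simplex_backward` is a hypothetical name.

```python
import numpy as np


def simplex_backward(v):
    """Map unconstrained rows of shape (..., k-1) to points on the k-simplex."""
    # Append minus the row sum, mirroring Join(1, v, Neg(ExpandDims(Sum(v)))).
    full = np.concatenate([v, -v.sum(axis=-1, keepdims=True)], axis=-1)
    # Numerically stable softmax over the last axis (subtract the row max).
    shifted = full - full.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
# Same input shape as the failing call: a batch of 10 two-dimensional points.
points = simplex_backward(rng.normal(size=(10, 2)))
```

In this sketch the appended column is computed from the batched input, so it carries the batch dimension automatically, unlike the `[[0.]]` constant in `Join [id D]`.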
PyMC version information:
'5.16.2+3.gc8b22df21'
Context for the issue:
This issue came up during Empirical Bayes experiments that set hyperpriors automatically.