BUG: Dirichlet is broken under graph vectorization  #7414

Closed
@ferrine

Description

Describe the issue:

While working on vectorizing a model's logp, I got stuck because the Dirichlet logp graph breaks under `vectorize_graph`. This is most likely a PyTensor issue, since the error originates in the rewrites (the same function runs when compiled with mode="FAST_COMPILE").

Reproducible code example:

import numpy as np
import pymc as pm
import pytensor

with pm.Model() as test_model:
    v = pm.Dirichlet("v", [1., 1., 1.], shape=3)
test_lp = test_model.logp()
# Give every value variable a leading batch dimension of unknown length
test_values_vec = {
    v: v.type.clone(shape=(None, *v.type.shape))(name=f"{v.name}[]")
    for v in test_model.value_vars
}
test_vectorized_lp = pytensor.graph.vectorize_graph([test_lp], test_values_vec)
test_f = pm.compile_pymc(
    # works with mode="FAST_COMPILE"
    list(test_values_vec.values()), test_vectorized_lp
)
# The simplex-transformed value has 2 entries, so a batch of 10 is (10, 2)
test_f(np.random.randn(10, 2))

Error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/dev/pymc/.venv/lib/python3.11/site-packages/pytensor/compile/function/types.py:959, in Function.__call__(self, *args, **kwargs)
    957 try:
    958     outputs = (
--> 959         self.vm()
    960         if output_subset is None
    961         else self.vm(output_subset=output_subset)
    962     )
    963 except Exception:

ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 0, the array at index 0 has size 10 and the array at index 1 has size 1

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[307], line 1
----> 1 test_f(np.random.randn(10, 2))

File ~/dev/pymc/.venv/lib/python3.11/site-packages/pytensor/compile/function/types.py:972, in Function.__call__(self, *args, **kwargs)
    970     if hasattr(self.vm, "thunks"):
    971         thunk = self.vm.thunks[self.vm.position_of_error]
--> 972     raise_with_op(
    973         self.maker.fgraph,
    974         node=self.vm.nodes[self.vm.position_of_error],
    975         thunk=thunk,
    976         storage_map=getattr(self.vm, "storage_map", None),
    977     )
    978 else:
    979     # old-style linkers raise their own exceptions
    980     raise

File ~/dev/pymc/.venv/lib/python3.11/site-packages/pytensor/link/utils.py:528, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    523     warnings.warn(
    524         f"{exc_type} error does not allow us to add an extra error message"
    525     )
    526     # Some exception need extra parameter in inputs. So forget the
    527     # extra long error message in that case.
--> 528 raise exc_value.with_traceback(exc_trace)

File ~/dev/pymc/.venv/lib/python3.11/site-packages/pytensor/compile/function/types.py:959, in Function.__call__(self, *args, **kwargs)
    956 t0_fn = time.perf_counter()
    957 try:
    958     outputs = (
--> 959         self.vm()
    960         if output_subset is None
    961         else self.vm(output_subset=output_subset)
    962     )
    963 except Exception:
    964     restore_defaults()

ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 0, the array at index 0 has size 10 and the array at index 1 has size 1
Apply node that caused the error: Join(1, Add.0, [[0.]])
Toposort index: 5
Inputs types: [TensorType(int8, shape=()), TensorType(float64, shape=(None, 2)), TensorType(float64, shape=(1, 1))]
Inputs shapes: [(), (10, 2), (1, 1)]
Inputs strides: [(), (16, 8), (8, 8)]
Inputs values: [array(1, dtype=int8), 'not shown', array([[0.]])]
Outputs clients: [[Max{axis=1}(Join.0), Composite{switch(i2, i3, exp((i0 - i1)))}(Join.0, ExpandDims{axis=1}.0, Isinf.0, Exp.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/home/ferres/dev/pymc/.venv/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
    coro.send(None)
  File "/home/ferres/dev/pymc/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/home/ferres/dev/pymc/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/home/ferres/dev/pymc/.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_921353/1334936598.py", line 8, in <module>
    test_vectorized_lp = pytensor.graph.vectorize_graph([test_lp], test_values_vec)
  File "/home/ferres/dev/pymc/.venv/lib/python3.11/site-packages/pytensor/graph/replace.py", line 301, in vectorize_graph
    vect_node = vectorize_node(node, *vect_inputs)
  File "/home/ferres/dev/pymc/.venv/lib/python3.11/site-packages/pytensor/graph/replace.py", line 217, in vectorize_node
    return _vectorize_node(op, node, *batched_inputs)
  File "/home/ferres/dev/pymc/.venv/lib/python3.11/functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
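
For reference, the shape mismatch reported for the `Join(1, Add.0, [[0.]])` node can be imitated in plain NumPy. This is only a sketch using the shapes from the "Inputs shapes" line above, not the actual PyTensor code path; the helper `try_join` and the array names are made up for illustration, and the idea that the constant would need broadcasting along the batch dimension is my assumption about what a correct graph would do.

```python
import numpy as np

# Shapes copied from the "Inputs shapes" line of the error report.
batched = np.zeros((10, 2))   # stands in for Add.0, the batched input
constant = np.zeros((1, 1))   # stands in for the [[0.]] constant


def try_join(a, b, axis=1):
    """Return the joined shape, or the error message if the shapes disagree."""
    try:
        return np.concatenate([a, b], axis=axis).shape
    except ValueError as err:
        return str(err)


# Joining along axis 1 requires matching sizes on axis 0, so this fails.
failed = try_join(batched, constant)

# Broadcasting the constant along the batch dimension first makes the join valid.
constant_b = np.broadcast_to(constant, (batched.shape[0], 1))
joined_shape = try_join(batched, constant_b)

print(failed)
print(joined_shape)
```

The first call returns an error message instead of a shape, while the second call succeeds once the constant has one row per batch entry.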

Graph of the function

Composite{(switch(or(i3, i4), -inf, 0.6931471805599453) + ((1.0986122886681098 + (3.0 * i2)) - (3.0 * (i1 + log(i0)))))} [id A] 20
 ├─ Sum{axis=1} [id B] 17
 │  └─ Composite{switch(i2, i3, exp((i0 - i1)))} [id C] 14
 │     ├─ Join [id D] 5
 │     │  ├─ 1 [id E]
 │     │  ├─ Add [id F] 3
 │     │  │  ├─ v_simplex__[] [id G]
 │     │  │  └─ ExpandDims{axis=1} [id H] 1
 │     │  │     └─ Sum{axis=1} [id I] 0
 │     │  │        └─ v_simplex__[] [id G]
 │     │  └─ [[0.]] [id J]
 │     ├─ ExpandDims{axis=1} [id K] 9
 │     │  └─ Max{axis=1} [id L] 7
 │     │     └─ Join [id D] 5
 │     │        └─ ···
 │     ├─ Isinf [id M] 12
 │     │  └─ ExpandDims{axis=1} [id K] 9
 │     │     └─ ···
 │     └─ Exp [id N] 11
 │        └─ ExpandDims{axis=1} [id K] 9
 │           └─ ···
 ├─ Max{axis=1} [id L] 7
 │  └─ ···
 ├─ Sum{axis=1} [id I] 0
 │  └─ ···
 ├─ Any{axis=1} [id O] 19
 │  └─ Lt [id P] 16
 │     ├─ Softmax{axis=1} [id Q] 13
 │     │  └─ Sub [id R] 10
 │     │     ├─ Join [id S] 4
 │     │     │  ├─ 1 [id E]
 │     │     │  ├─ v_simplex__[] [id G]
 │     │     │  └─ Neg [id T] 2
 │     │     │     └─ ExpandDims{axis=1} [id H] 1
 │     │     │        └─ ···
 │     │     └─ ExpandDims{axis=1} [id U] 8
 │     │        └─ Max{axis=1} [id V] 6
 │     │           └─ Join [id S] 4
 │     │              └─ ···
 │     └─ [[0]] [id W]
 └─ Any{axis=1} [id X] 18
    └─ Gt [id Y] 15
       ├─ Softmax{axis=1} [id Q] 13
       │  └─ ···
       └─ [[1]] [id Z]

Inner graphs:

Composite{(switch(or(i3, i4), -inf, 0.6931471805599453) + ((1.0986122886681098 + (3.0 * i2)) - (3.0 * (i1 + log(i0)))))} [id A]
 ← add [id BA] 'o0'
    ├─ Switch [id BB]
    │  ├─ OR [id BC]
    │  │  ├─ i3 [id BD]
    │  │  └─ i4 [id BE]
    │  ├─ -inf [id BF]
    │  └─ 0.6931471805599453 [id BG]
    └─ sub [id BH]
       ├─ add [id BI]
       │  ├─ 1.0986122886681098 [id BJ]
       │  └─ mul [id BK]
       │     ├─ t1{3.0} [id BL]
       │     └─ i2 [id BM]
       └─ mul [id BN]
          ├─ t1{3.0} [id BL]
          └─ add [id BO]
             ├─ i1 [id BP]
             └─ log [id BQ]
                └─ i0 [id BR]

Composite{switch(i2, i3, exp((i0 - i1)))} [id C]
 ← Switch [id BS] 'o0'
    ├─ i2 [id BT]
    ├─ i3 [id BU]
    └─ exp [id BV]
       └─ sub [id BW]
          ├─ i0 [id BX]
          └─ i1 [id BY]

PyMC version information:

'5.16.2+3.gc8b22df21'

Context for the issue:

This issue came up during Empirical Bayes experiments that set hyperpriors automatically.
