Description
Describe the issue:
I just ran into a logprob rewrite error with an AdvancedSubtensor op that mixed None entries and int32 indices. The model wasn't actually a mixture model, but logprob matched the op, tried to apply its mixture rewrite (find_measurable_index_mixture), and raised an error instead of just failing silently. The problem seems to come from this line, which guards against slice constants but not against None constants, so it ends up reading .dtype from the None indexer.
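The failure can be illustrated without PyTensor. The sketch below uses hypothetical stand-in objects (tensor_index and none_index are not real PyTensor or PyMC names) and evaluates the same expression that appears at mixture.py line 292 in the traceback. Because the None indexer has a type but no dtype, the generator raises AttributeError before any() can return.

```python
from types import SimpleNamespace


def tensor_index(dtype, broadcastable):
    # Hypothetical stand-in for an integer TensorVariable used as an index.
    return SimpleNamespace(dtype=dtype, type=SimpleNamespace(broadcastable=broadcastable))


def none_index():
    # Hypothetical stand-in for NoneConst: it has a .type but no .dtype.
    return SimpleNamespace(type=SimpleNamespace())


def has_int_vector_index_unguarded(indices):
    # Same expression as in the traceback, with no guard for non-tensor
    # indexers: every index is assumed to expose .dtype.
    return any(
        idx.dtype.startswith("int") and sum(1 - b for b in idx.type.broadcastable) > 0
        for idx in indices
    )


# Index inputs of AdvancedSubtensor(a, NoneConst{None}, [0 1 2 3]), after `a`.
node_indices = [none_index(), tensor_index("int32", (False,))]

caught = None
try:
    has_int_vector_index_unguarded(node_indices)
except AttributeError as err:
    caught = err

print(type(caught).__name__, caught)
```

The None indexer comes first in the node's inputs, so the error is raised before the int32 index is ever inspected.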
Reproducible code example:
import numpy as np
import pymc as pm

obs = np.random.default_rng().normal(size=(7, 4))

with pm.Model():
    inds = np.arange(obs.shape[1])
    a = pm.Normal("a", shape=10)
    b = pm.Deterministic("b", a[None, inds])
    c = pm.Normal("c", mu=b, sigma=1, observed=obs)
    pm.sample()
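For reference, the indexing in the reproducer is an ordinary gather with an extra leading axis, not a mixture. The NumPy sketch below uses plain arrays as stand-ins for the PyMC variables (the values are arbitrary) and shows the shape that a[None, inds] produces and how it broadcasts against the (7, 4) observations.

```python
import numpy as np

a = np.arange(10.0)      # stand-in for the length-10 vector "a"
inds = np.arange(4)      # integer index vector, as in the reproducer
b = a[None, inds]        # None inserts a leading axis; inds gathers elements

obs = np.zeros((7, 4))   # same shape as the observed data
combined = np.broadcast_shapes(b.shape, obs.shape)

print(b.shape, combined)
```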
Error message:
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: find_measurable_index_mixture
ERROR (pytensor.graph.rewriting.basic): node: AdvancedSubtensor(a, NoneConst{None}, [0 1 2 3])
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "pytensor/graph/rewriting/basic.py", line 1913, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "pytensor/graph/rewriting/basic.py", line 1085, in transform
    return self.fn(fgraph, node)
  File "pymc/logprob/mixture.py", line 291, in find_measurable_index_mixture
    if any(
  File "pymc/logprob/mixture.py", line 292, in <genexpr>
    indices.dtype.startswith("int") and sum(1 - b for b in indices.type.broadcastable) > 0
AttributeError: 'Constant' object has no attribute 'dtype'. Did you mean: 'type'?
Sampling still works fine, though, because the rewrite was supposed to fail and return None anyway.
PyMC version information:
Github main
Context for the issue:
This doesn't really affect anything; it just confuses regular users, who see the rewrite traceback and get alarmed. It would be cleaner to handle this extra indexer type (None constants) the same way slice constants are already handled.
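One way to handle the extra indexer type is to skip any index that has no dtype before testing it, so None constants are ignored just like slice constants. This is a sketch with hypothetical stand-ins (tensor_index, none_index, slice_index and is_int_vector_index are illustrative names, not PyMC code). In PyMC itself the guard would more likely test the index type directly, for example against NoneTypeT and SliceType from pytensor.tensor.type_other, assuming those are the relevant classes; treat this as an illustration of the idea, not the actual patch.

```python
from types import SimpleNamespace


def tensor_index(dtype, broadcastable):
    # Hypothetical stand-in for an integer TensorVariable used as an index.
    return SimpleNamespace(dtype=dtype, type=SimpleNamespace(broadcastable=broadcastable))


def none_index():
    # Hypothetical stand-in for NoneConst (no dtype attribute).
    return SimpleNamespace(type=SimpleNamespace())


def slice_index():
    # Hypothetical stand-in for a slice constant (also no dtype attribute).
    return SimpleNamespace(type=SimpleNamespace())


def is_int_vector_index(idx):
    # Non-tensor indexers (None and slice constants) have no dtype, so they
    # cannot be integer index vectors: report False instead of raising.
    dtype = getattr(idx, "dtype", None)
    if dtype is None:
        return False
    return dtype.startswith("int") and sum(1 - b for b in idx.type.broadcastable) > 0


node_indices = [none_index(), tensor_index("int32", (False,))]
print(any(is_int_vector_index(idx) for idx in node_indices))
```

With the guard in place, the check simply moves past the None indexer, and whether the rewrite applies is decided by the remaining logic rather than by an exception.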