
BUG: AdvancedSubtensor with None and integer indices raises a logprob error instead of silently failing #7762

Open
@lucianopaz

Description


Describe the issue:

I just ran into a logprob rewrite error with an AdvancedSubtensor op that mixed None entries and integer indices. The model isn't actually a mixture, but logprob matched the op, tried to apply its find_measurable_index_mixture rewrite, and raised an error instead of silently declining. The problem seems to come from this line in pymc/logprob/mixture.py, which guards against slice constants but not against a None constant.

Reproducible code example:

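import numpy as np
import pymc as pm


obs = np.random.default_rng().normal(size=(7, 4))
with pm.Model():
    inds = np.arange(obs.shape[1])  # integer indices [0, 1, 2, 3]
    a = pm.Normal("a", shape=10)
    # a[None, inds] mixes a None (new axis) with an integer array, which
    # builds an AdvancedSubtensor node; the result has shape (1, 4)
    b = pm.Deterministic("b", a[None, inds])
    # b broadcasts against the (7, 4) observations
    c = pm.Normal("c", mu=b, sigma=1, observed=obs)
    pm.sample()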

Error message:

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: find_measurable_index_mixture
ERROR (pytensor.graph.rewriting.basic): node: AdvancedSubtensor(a, NoneConst{None}, [0 1 2 3])
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "pytensor/graph/rewriting/basic.py", line 1913, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "pytensor/graph/rewriting/basic.py", line 1085, in transform
    return self.fn(fgraph, node)
  File "pymc/logprob/mixture.py", line 291, in find_measurable_index_mixture
    if any(
  File "pymc/logprob/mixture.py", line 292, in <genexpr>
    indices.dtype.startswith("int") and sum(1 - b for b in indices.type.broadcastable) > 0
AttributeError: 'Constant' object has no attribute 'dtype'. Did you mean: 'type'?

Sampling still works, however, because the rewrite was supposed to fail and return None anyway; the only visible effect is the logged traceback.

PyMC version information:

GitHub main

Context for the issue:

This doesn't affect results. It just confuses regular users, who see the rewrite-failure traceback and get alarmed. It would be cleaner to handle this extra indexer type the same way slice constants are already handled.
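A minimal, self-contained sketch of the failure and of the suggested guard, in plain Python. The classes SliceIndex, NoneIndex and IntTensorIndex and the two has_nonscalar_int_index_* functions are hypothetical stand-ins for the PyTensor indexer inputs and for the check inside find_measurable_index_mixture; they are not PyMC or PyTensor APIs, and the broadcastable pattern is simplified to live directly on the index instead of on its type. The actual fix in PyMC may look different (for example, checking the input's type rather than its class).

```python
from dataclasses import dataclass


# Hypothetical stand-ins for the kinds of indexer inputs an AdvancedSubtensor
# node can carry. In PyTensor these would roughly correspond to a
# SliceConstant, the NoneConst constant, and an integer TensorConstant.
@dataclass
class SliceIndex:
    value: slice


@dataclass
class NoneIndex:
    # Like NoneConst, this stand-in has no `dtype` attribute.
    value: None = None


@dataclass
class IntTensorIndex:
    dtype: str
    broadcastable: tuple


def has_nonscalar_int_index_buggy(indices):
    # Mirrors the failing check: only slice indexers are skipped,
    # so a NoneIndex reaches `.dtype` and raises AttributeError.
    return any(
        idx.dtype.startswith("int") and sum(1 - b for b in idx.broadcastable) > 0
        for idx in indices
        if not isinstance(idx, SliceIndex)
    )


def has_nonscalar_int_index_fixed(indices):
    # Skips both slice and None indexers before touching `.dtype`.
    return any(
        idx.dtype.startswith("int") and sum(1 - b for b in idx.broadcastable) > 0
        for idx in indices
        if not isinstance(idx, (SliceIndex, NoneIndex))
    )


# The indexers from the reported node: a[None, [0, 1, 2, 3]]
node_indices = [NoneIndex(), IntTensorIndex("int64", (False,))]

try:
    has_nonscalar_int_index_buggy(node_indices)
except AttributeError as err:
    print("buggy check raised:", err)

# The fixed check detects the non-scalar integer index, which is the
# condition under which the real rewrite is meant to return None.
print("fixed check:", has_nonscalar_int_index_fixed(node_indices))
```

With the guard in place, a None indexer is simply ignored by the check, so the rewrite can decline cleanly instead of logging a traceback.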
