diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index ce6a11d208..c8fac713ce 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -292,6 +292,7 @@ def find_measurable_index_mixture(fgraph, node): indices.dtype.startswith("int") and sum(1 - b for b in indices.type.broadcastable) > 0 for indices in mixing_indices if not isinstance(indices, SliceConstant) + if not isinstance(indices, type(NoneConst)) ): return None diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index ffb2bf07c0..7b4ec421f1 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -1156,3 +1156,26 @@ def test_nested_ifelse(): np.testing.assert_almost_equal(mix_logp_fn(0, test_value), sp.norm.logpdf(test_value, -5, 1)) np.testing.assert_almost_equal(mix_logp_fn(1, test_value), sp.norm.logpdf(test_value, 0, 1)) np.testing.assert_almost_equal(mix_logp_fn(2, test_value), sp.norm.logpdf(test_value, 5, 1)) + + +def test_advanced_subtensor_none_and_integer(): + a = pt.random.normal(0, 1, size=(10,), name="a") + inds = np.array([0, 1, 2, 3], dtype="int32") + b = a[None, inds] + + b_val = b.type() + b_val.name = "b_val" + a_val = a.type() + a_val.name = "a_val" + + try: + # b is tested by being rewritten with logp, it should have a runtime error + logp_dict = conditional_logp({b: b_val, a: a_val}) + + # A runtime error means that a value of None was assigned to b, instead of the internal attribution error + except RuntimeError as e: + if "AttributeError" in str(e): + assert False, f"Rewrite failed with original bug: {e}" + + except Exception as e: + assert False, f"Rewrite raised an unexpected error: {e}" \ No newline at end of file