Skip to content

Commit f374411

Browse files
committed
Fix __bool__ usages with PyTensor variables
1 parent 0035ab7 commit f374411

File tree

7 files changed

+11
-11
lines changed

7 files changed

+11
-11
lines changed

pymc/distributions/multivariate.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2057,14 +2057,12 @@ class KroneckerNormal(Continuous):
20572057
rv_op = KroneckerNormalRV.rv_op
20582058

20592059
@classmethod
2060-
def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs):
2060+
def dist(cls, mu, covs=None, chols=None, evds=None, sigma=0.0, *args, **kwargs):
20612061
if len([i for i in [covs, chols, evds] if i is not None]) != 1:
20622062
raise ValueError(
20632063
"Incompatible parameterization. Specify exactly one of covs, chols, or evds."
20642064
)
20652065

2066-
sigma = sigma if sigma else 0
2067-
20682066
if chols is not None:
20692067
covs = [chol.dot(chol.T) for chol in chols]
20702068
elif evds is not None:
@@ -2076,6 +2074,7 @@ def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs)
20762074
covs.append(cov_i)
20772075

20782076
mu = pt.as_tensor_variable(mu)
2077+
sigma = pt.as_tensor_variable(sigma)
20792078

20802079
return super().dist([mu, sigma, *covs], **kwargs)
20812080

pymc/logprob/transforms.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,9 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Apply) -> list[Varia
470470

471471
# Do not apply rewrite to discrete variables except for their addition and negation
472472
if measurable_input.type.dtype.startswith("int"):
473-
if not (find_negated_var(measurable_output) or isinstance(node.op.scalar_op, Add)):
473+
if not (
474+
find_negated_var(measurable_output) is not None or isinstance(node.op.scalar_op, Add)
475+
):
474476
return None
475477
# Do not allow rewrite if output is cast to a float, because we don't have meta-info on the type of the MeasurableVariable
476478
if not measurable_output.type.dtype.startswith("int"):

pymc/logprob/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,13 +186,12 @@ def expand_fn(var):
186186
return []
187187

188188
if any(
189-
ancestor_var
190-
for ancestor_var in walk(inputs, expand=expand_fn, bfs=False)
191-
if (
189+
(
192190
ancestor_var.owner
193191
and isinstance(ancestor_var.owner.op, MeasurableOp)
194192
and not isinstance(ancestor_var.owner.op, ValuedRV)
195193
)
194+
for ancestor_var in walk(inputs, expand=expand_fn, bfs=False)
196195
):
197196
return True
198197
return False

tests/distributions/test_custom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ def test_symbolic_dist(self):
604604
def dist(size):
605605
return Truncated.dist(Beta.dist(1, 1, size=size), lower=0.1, upper=0.9)
606606

607-
assert CustomDist.dist(dist=dist)
607+
CustomDist.dist(dist=dist)
608608

609609
def test_nested_custom_dist(self):
610610
"""Test we can create CustomDist that creates another CustomDist"""

tests/distributions/test_multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def test_matrixnormal(self, n):
421421

422422
@pytest.mark.parametrize("n", [2, 3])
423423
@pytest.mark.parametrize("m", [3])
424-
@pytest.mark.parametrize("sigma", [None, 1])
424+
@pytest.mark.parametrize("sigma", [0, 1])
425425
def test_kroneckernormal(self, n, m, sigma):
426426
np.random.seed(5)
427427
N = n * m

tests/logprob/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def test_warn_random_found_probability_inference(func, scipy_func, test_value):
382382
with pytest.warns(
383383
UserWarning, match="RandomVariables {input} were found in the derived graph"
384384
):
385-
assert func(rv, 0.0)
385+
func(rv, 0.0)
386386

387387
res = func(rv, 0.0, warn_rvs=False)
388388
# This is the problem we are warning about, as now we can no longer identify the original rv in the graph

tests/logprob/test_order.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,4 +291,4 @@ def test_non_measurable_max_grad():
291291
joint_logp = pt.sum([term.sum() for term in logp_terms])
292292

293293
# Test that calling gradient does not raise a NotImplementedError
294-
assert pt.grad(joint_logp, x_vv)
294+
pt.grad(joint_logp, x_vv)

0 commit comments

Comments
 (0)