Skip to content

Commit 55d915c

Browse files
authored
Make logprob inference for binary ops independent of order of inputs (#6682)
1 parent a59c9cd commit 55d915c

File tree

2 files changed

+60
-23
lines changed

2 files changed

+60
-23
lines changed

pymc/logprob/binary.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from pytensor.graph.fg import FunctionGraph
2121
from pytensor.graph.rewriting.basic import node_rewriter
2222
from pytensor.scalar.basic import GE, GT, LE, LT
23+
from pytensor.tensor import TensorVariable
2324
from pytensor.tensor.math import ge, gt, le, lt
2425

2526
from pymc.logprob.abstract import (
@@ -50,26 +51,49 @@ def find_measurable_comparisons(
5051
if isinstance(node.op, MeasurableComparison):
5152
return None # pragma: no cover
5253

53-
(compared_var,) = node.outputs
54-
base_var, const = node.inputs
54+
measurable_inputs = [
55+
(inp, idx)
56+
for idx, inp in enumerate(node.inputs)
57+
if inp.owner
58+
and isinstance(inp.owner.op, MeasurableVariable)
59+
and inp not in rv_map_feature.rv_values
60+
]
5561

56-
if not (
57-
base_var.owner
58-
and isinstance(base_var.owner.op, MeasurableVariable)
59-
and base_var not in rv_map_feature.rv_values
60-
):
62+
if len(measurable_inputs) != 1:
6163
return None
6264

65+
# Make the measurable base_var always be the first input to the MeasurableComparison node
66+
base_var: TensorVariable = measurable_inputs[0][0]
67+
68+
# Check that the other input is not potentially measurable, in which case this rewrite
69+
# would be invalid
70+
const = tuple(inp for inp in node.inputs if inp is not base_var)
71+
6372
# check for potential measurability of const
64-
if not check_potential_measurability((const,), rv_map_feature):
73+
if not check_potential_measurability(const, rv_map_feature):
6574
return None
6675

76+
const = const[0]
77+
6778
# Make base_var unmeasurable
6879
unmeasurable_base_var = ignore_logprob(base_var)
6980

70-
compared_op = MeasurableComparison(node.op.scalar_op)
81+
node_scalar_op = node.op.scalar_op
82+
83+
# Change the Op if the base_var is the second input in node.inputs. e.g. pt.lt(const, dist) -> pt.gt(dist, const)
84+
if measurable_inputs[0][1] == 1:
85+
if isinstance(node_scalar_op, LT):
86+
node_scalar_op = GT()
87+
elif isinstance(node_scalar_op, GT):
88+
node_scalar_op = LT()
89+
elif isinstance(node_scalar_op, GE):
90+
node_scalar_op = LE()
91+
elif isinstance(node_scalar_op, LE):
92+
node_scalar_op = GE()
93+
94+
compared_op = MeasurableComparison(node_scalar_op)
7195
compared_rv = compared_op.make_node(unmeasurable_base_var, const).default_output()
72-
compared_rv.name = compared_var.name
96+
compared_rv.name = node.outputs[0].name
7397
return [compared_rv]
7498

7599

tests/logprob/test_binary.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,17 @@
2525

2626

2727
@pytest.mark.parametrize(
28-
"comparison_op, exp_logp_true, exp_logp_false",
28+
"comparison_op, exp_logp_true, exp_logp_false, inputs",
2929
[
30-
((pt.lt, pt.le), "logcdf", "logsf"),
31-
((pt.gt, pt.ge), "logsf", "logcdf"),
30+
((pt.lt, pt.le), "logcdf", "logsf", (pt.random.normal(0, 1), 0.5)),
31+
((pt.gt, pt.ge), "logsf", "logcdf", (pt.random.normal(0, 1), 0.5)),
32+
((pt.lt, pt.le), "logsf", "logcdf", (0.5, pt.random.normal(0, 1))),
33+
((pt.gt, pt.ge), "logcdf", "logsf", (0.5, pt.random.normal(0, 1))),
3234
],
3335
)
34-
def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
35-
x_rv = pt.random.normal(0, 1)
36+
def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false, inputs):
3637
for op in comparison_op:
37-
comp_x_rv = op(x_rv, 0.5)
38+
comp_x_rv = op(*inputs)
3839

3940
comp_x_vv = comp_x_rv.clone()
4041

@@ -49,33 +50,45 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
4950

5051

5152
@pytest.mark.parametrize(
52-
"comparison_op, exp_logp_true, exp_logp_false",
53+
"comparison_op, exp_logp_true, exp_logp_false, inputs",
5354
[
5455
(
5556
pt.lt,
5657
lambda x: st.poisson(2).logcdf(x - 1),
5758
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
59+
(pt.random.poisson(2), 3),
5860
),
5961
(
6062
pt.ge,
6163
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
6264
lambda x: st.poisson(2).logcdf(x - 1),
65+
(pt.random.poisson(2), 3),
6366
),
67+
(pt.gt, st.poisson(2).logsf, st.poisson(2).logcdf, (pt.random.poisson(2), 3)),
68+
(pt.le, st.poisson(2).logcdf, st.poisson(2).logsf, (pt.random.poisson(2), 3)),
6469
(
65-
pt.gt,
70+
pt.lt,
6671
st.poisson(2).logsf,
6772
st.poisson(2).logcdf,
73+
(3, pt.random.poisson(2)),
74+
),
75+
(pt.ge, st.poisson(2).logcdf, st.poisson(2).logsf, (3, pt.random.poisson(2))),
76+
(
77+
pt.gt,
78+
lambda x: st.poisson(2).logcdf(x - 1),
79+
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
80+
(3, pt.random.poisson(2)),
6881
),
6982
(
7083
pt.le,
71-
st.poisson(2).logcdf,
72-
st.poisson(2).logsf,
84+
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
85+
lambda x: st.poisson(2).logcdf(x - 1),
86+
(3, pt.random.poisson(2)),
7387
),
7488
],
7589
)
76-
def test_discrete_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
77-
x_rv = pt.random.poisson(2)
78-
cens_x_rv = comparison_op(x_rv, 3)
90+
def test_discrete_rv_comparison(inputs, comparison_op, exp_logp_true, exp_logp_false):
91+
cens_x_rv = comparison_op(*inputs)
7992

8093
cens_x_vv = cens_x_rv.clone()
8194

0 commit comments

Comments
 (0)