Skip to content

Commit ccb9cfb

Browse files
committed
Remove joint_logprob function from tests.logprob.utils
1 parent 1ed4475 commit ccb9cfb

File tree

9 files changed

+195
-174
lines changed

9 files changed

+195
-174
lines changed

tests/logprob/test_censoring.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@
4141
import scipy as sp
4242
import scipy.stats as st
4343

44+
from pymc import logp
4445
from pymc.logprob import factorized_joint_logprob
4546
from pymc.logprob.transforms import LogTransform, TransformValuesRewrite
4647
from pymc.testing import assert_no_rvs
47-
from tests.logprob.utils import joint_logprob
4848

4949

5050
@pytensor.config.change_flags(compute_test_value="raise")
@@ -55,10 +55,10 @@ def test_continuous_rv_clip():
5555
cens_x_vv = cens_x_rv.clone()
5656
cens_x_vv.tag.test_value = 0
5757

58-
logp = joint_logprob({cens_x_rv: cens_x_vv})
59-
assert_no_rvs(logp)
58+
logprob = pt.sum(logp(cens_x_rv, cens_x_vv))
59+
assert_no_rvs(logprob)
6060

61-
logp_fn = pytensor.function([cens_x_vv], logp)
61+
logp_fn = pytensor.function([cens_x_vv], logprob)
6262
ref_scipy = st.norm(0.5, 1)
6363

6464
assert logp_fn(-3) == -np.inf
@@ -75,10 +75,10 @@ def test_discrete_rv_clip():
7575

7676
cens_x_vv = cens_x_rv.clone()
7777

78-
logp = joint_logprob({cens_x_rv: cens_x_vv})
79-
assert_no_rvs(logp)
78+
logprob = pt.sum(logp(cens_x_rv, cens_x_vv))
79+
assert_no_rvs(logprob)
8080

81-
logp_fn = pytensor.function([cens_x_vv], logp)
81+
logp_fn = pytensor.function([cens_x_vv], logprob)
8282
ref_scipy = st.poisson(2)
8383

8484
assert logp_fn(0) == -np.inf
@@ -97,8 +97,8 @@ def test_one_sided_clip():
9797
lb_cens_x_vv = lb_cens_x_rv.clone()
9898
ub_cens_x_vv = ub_cens_x_rv.clone()
9999

100-
lb_logp = joint_logprob({lb_cens_x_rv: lb_cens_x_vv})
101-
ub_logp = joint_logprob({ub_cens_x_rv: ub_cens_x_vv})
100+
lb_logp = pt.sum(logp(lb_cens_x_rv, lb_cens_x_vv))
101+
ub_logp = pt.sum(logp(ub_cens_x_rv, ub_cens_x_vv))
102102
assert_no_rvs(lb_logp)
103103
assert_no_rvs(ub_logp)
104104

@@ -117,10 +117,10 @@ def test_useless_clip():
117117

118118
cens_x_vv = cens_x_rv.clone()
119119

120-
logp = joint_logprob({cens_x_rv: cens_x_vv}, sum=False)
121-
assert_no_rvs(logp)
120+
logprob = logp(cens_x_rv, cens_x_vv)
121+
assert_no_rvs(logprob)
122122

123-
logp_fn = pytensor.function([cens_x_vv], logp)
123+
logp_fn = pytensor.function([cens_x_vv], logprob)
124124
ref_scipy = st.norm(0.5, 1)
125125

126126
np.testing.assert_allclose(logp_fn([-2, 0, 2]), ref_scipy.logpdf([-2, 0, 2]))
@@ -133,10 +133,12 @@ def test_random_clip():
133133

134134
lb_vv = lb_rv.clone()
135135
cens_x_vv = cens_x_rv.clone()
136-
logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv}, sum=False)
137-
assert_no_rvs(logp)
136+
logp = factorized_joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
137+
logp_combined = pt.add(*logp.values())
138138

139-
logp_fn = pytensor.function([lb_vv, cens_x_vv], logp)
139+
assert_no_rvs(logp_combined)
140+
141+
logp_fn = pytensor.function([lb_vv, cens_x_vv], logp_combined)
140142
res = logp_fn([0, -1], [-1, -1])
141143
assert res[0] == -np.inf
142144
assert res[1] != -np.inf
@@ -150,8 +152,10 @@ def test_broadcasted_clip_constant():
150152
lb_vv = lb_rv.clone()
151153
cens_x_vv = cens_x_rv.clone()
152154

153-
logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
154-
assert_no_rvs(logp)
155+
logp = factorized_joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
156+
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
157+
158+
assert_no_rvs(logp_combined)
155159

156160

157161
def test_broadcasted_clip_random():
@@ -162,8 +166,10 @@ def test_broadcasted_clip_random():
162166
lb_vv = lb_rv.clone()
163167
cens_x_vv = cens_x_rv.clone()
164168

165-
logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
166-
assert_no_rvs(logp)
169+
logp = factorized_joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
170+
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
171+
172+
assert_no_rvs(logp_combined)
167173

168174

169175
def test_fail_base_and_clip_have_values():
@@ -199,10 +205,11 @@ def test_deterministic_clipping():
199205

200206
x_vv = x_rv.clone()
201207
y_vv = y_rv.clone()
202-
logp = joint_logprob({x_rv: x_vv, y_rv: y_vv})
203-
assert_no_rvs(logp)
208+
logp = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
209+
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
210+
assert_no_rvs(logp_combined)
204211

205-
logp_fn = pytensor.function([x_vv, y_vv], logp)
212+
logp_fn = pytensor.function([x_vv, y_vv], logp_combined)
206213
assert np.isclose(
207214
logp_fn(-1, 1),
208215
st.norm(0, 1).logpdf(-1) + st.norm(0, 1).logpdf(1),
@@ -216,10 +223,11 @@ def test_clip_transform():
216223
cens_x_vv = cens_x_rv.clone()
217224

218225
transform = TransformValuesRewrite({cens_x_vv: LogTransform()})
219-
logp = joint_logprob({cens_x_rv: cens_x_vv}, extra_rewrites=transform)
226+
logp = factorized_joint_logprob({cens_x_rv: cens_x_vv}, extra_rewrites=transform)
227+
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
220228

221229
cens_x_vv_testval = -1
222-
obs_logp = logp.eval({cens_x_vv: cens_x_vv_testval})
230+
obs_logp = logp_combined.eval({cens_x_vv: cens_x_vv_testval})
223231
exp_logp = sp.stats.norm(0.5, 1).logpdf(np.exp(cens_x_vv_testval)) + cens_x_vv_testval
224232

225233
assert np.isclose(obs_logp, exp_logp)
@@ -236,8 +244,8 @@ def test_rounding(rounding_op):
236244
xr.name = "xr"
237245

238246
xr_vv = xr.clone()
239-
logp = joint_logprob({xr: xr_vv}, sum=False)
240-
assert logp is not None
247+
logprob = logp(xr, xr_vv)
248+
assert logprob is not None
241249

242250
x_sp = st.norm(loc, scale)
243251
if rounding_op == pt.round:
@@ -250,6 +258,6 @@ def test_rounding(rounding_op):
250258
raise NotImplementedError()
251259

252260
assert np.allclose(
253-
logp.eval({xr_vv: test_value}),
261+
logprob.eval({xr_vv: test_value}),
254262
expected_logp,
255263
)

tests/logprob/test_composite_logprob.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@
3939
import pytensor.tensor as pt
4040
import scipy.stats as st
4141

42+
from pymc import logp
43+
from pymc.logprob.basic import factorized_joint_logprob
4244
from pymc.logprob.censoring import MeasurableClip
4345
from pymc.logprob.rewriting import construct_ir_fgraph
4446
from pymc.testing import assert_no_rvs
45-
from tests.logprob.utils import joint_logprob
4647

4748

4849
def test_scalar_clipped_mixture():
@@ -61,9 +62,10 @@ def test_scalar_clipped_mixture():
6162
idxs_vv = idxs.clone()
6263
idxs_vv.name = "idxs_val"
6364

64-
logp = joint_logprob({idxs: idxs_vv, mix: mix_vv})
65+
logp = factorized_joint_logprob({idxs: idxs_vv, mix: mix_vv})
66+
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
6567

66-
logp_fn = pytensor.function([idxs_vv, mix_vv], logp)
68+
logp_fn = pytensor.function([idxs_vv, mix_vv], logp_combined)
6769
assert logp_fn(0, 0.4) == -np.inf
6870
assert np.isclose(logp_fn(0, 0.5), st.norm.logcdf(0.5, 1) + np.log(0.6))
6971
assert np.isclose(logp_fn(0, 1.3), st.norm.logpdf(1.3, 1) + np.log(0.6))
@@ -98,8 +100,12 @@ def test_nested_scalar_mixtures():
98100
idxs12_vv = idxs12.clone()
99101
mix12_vv = mix12.clone()
100102

101-
logp = joint_logprob({idxs1: idxs1_vv, idxs2: idxs2_vv, idxs12: idxs12_vv, mix12: mix12_vv})
102-
logp_fn = pytensor.function([idxs1_vv, idxs2_vv, idxs12_vv, mix12_vv], logp)
103+
logp = factorized_joint_logprob(
104+
{idxs1: idxs1_vv, idxs2: idxs2_vv, idxs12: idxs12_vv, mix12: mix12_vv}
105+
)
106+
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
107+
108+
logp_fn = pytensor.function([idxs1_vv, idxs2_vv, idxs12_vv, mix12_vv], logp_combined)
103109

104110
expected_mu_logpdf = st.norm.logpdf(0) + np.log(0.5) * 3
105111
assert np.isclose(logp_fn(0, 0, 0, -50), expected_mu_logpdf)
@@ -144,9 +150,9 @@ def test_shifted_cumsum():
144150
y.name = "y"
145151

146152
y_vv = y.clone()
147-
logp = joint_logprob({y: y_vv})
153+
logprob = logp(y, y_vv)
148154
assert np.isclose(
149-
logp.eval({y_vv: np.arange(5) + 1 + 5}),
155+
logprob.eval({y_vv: np.arange(5) + 1 + 5}).sum(),
150156
st.norm.logpdf(1) * 5,
151157
)
152158

@@ -157,8 +163,8 @@ def test_double_log_transform_rv():
157163
y_rv.name = "y"
158164

159165
y_vv = y_rv.clone()
160-
logp = joint_logprob({y_rv: y_vv}, sum=False)
161-
logp_fn = pytensor.function([y_vv], logp)
166+
logprob = logp(y_rv, y_vv)
167+
logp_fn = pytensor.function([y_vv], logprob)
162168

163169
log_log_y_val = np.asarray(0.5)
164170
log_y_val = np.exp(log_log_y_val)
@@ -178,9 +184,9 @@ def test_affine_transform_rv():
178184
y_rv.name = "y"
179185
y_vv = y_rv.clone()
180186

181-
logp = joint_logprob({y_rv: y_vv}, sum=False)
182-
assert_no_rvs(logp)
183-
logp_fn = pytensor.function([loc, scale, y_vv], logp)
187+
logprob = logp(y_rv, y_vv)
188+
assert_no_rvs(logprob)
189+
logp_fn = pytensor.function([loc, scale, y_vv], logprob)
184190

185191
loc_test_val = 4.0
186192
scale_test_val = np.full(rv_size, 0.5)
@@ -200,8 +206,8 @@ def test_affine_log_transform_rv():
200206

201207
y_vv = y_rv.clone()
202208

203-
logp = joint_logprob({y_rv: y_vv}, sum=False)
204-
logp_fn = pytensor.function([a, b, y_vv], logp)
209+
logprob = logp(y_rv, y_vv)
210+
logp_fn = pytensor.function([a, b, y_vv], logprob)
205211

206212
a_val = -1.5
207213
b_val = 3.0

tests/logprob/test_cumsum.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@
4040
import pytest
4141
import scipy.stats as st
4242

43+
from pymc import logp
44+
from pymc.logprob.basic import factorized_joint_logprob
4345
from pymc.testing import assert_no_rvs
44-
from tests.logprob.utils import joint_logprob
4546

4647

4748
@pytest.mark.parametrize(
@@ -59,12 +60,12 @@
5960
def test_normal_cumsum(size, axis):
6061
rv = pt.random.normal(0, 1, size=size).cumsum(axis)
6162
vv = rv.clone()
62-
logp = joint_logprob({rv: vv})
63-
assert_no_rvs(logp)
63+
logprob = logp(rv, vv)
64+
assert_no_rvs(logprob)
6465

6566
assert np.isclose(
6667
st.norm(0, 1).logpdf(np.ones(size)).sum(),
67-
logp.eval({vv: np.ones(size).cumsum(axis)}),
68+
logprob.eval({vv: np.ones(size).cumsum(axis)}).sum(),
6869
)
6970

7071

@@ -83,12 +84,12 @@ def test_normal_cumsum(size, axis):
8384
def test_bernoulli_cumsum(size, axis):
8485
rv = pt.random.bernoulli(0.9, size=size).cumsum(axis)
8586
vv = rv.clone()
86-
logp = joint_logprob({rv: vv})
87-
assert_no_rvs(logp)
87+
logprob = logp(rv, vv)
88+
assert_no_rvs(logprob)
8889

8990
assert np.isclose(
9091
st.bernoulli(0.9).logpmf(np.ones(size)).sum(),
91-
logp.eval({vv: np.ones(size, int).cumsum(axis)}),
92+
logprob.eval({vv: np.ones(size, int).cumsum(axis)}).sum(),
9293
)
9394

9495

@@ -97,7 +98,7 @@ def test_destructive_cumsum_fails():
9798
x_rv = pt.random.normal(size=(2, 2, 2)).cumsum()
9899
x_vv = x_rv.clone()
99100
with pytest.raises(RuntimeError, match="could not be derived"):
100-
joint_logprob({x_rv: x_vv})
101+
factorized_joint_logprob({x_rv: x_vv})
101102

102103

103104
def test_deterministic_cumsum():
@@ -108,11 +109,13 @@ def test_deterministic_cumsum():
108109

109110
x_vv = x_rv.clone()
110111
y_vv = y_rv.clone()
111-
logp = joint_logprob({x_rv: x_vv, y_rv: y_vv})
112-
assert_no_rvs(logp)
113112

114-
logp_fn = pytensor.function([x_vv, y_vv], logp)
113+
logp = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
114+
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
115+
assert_no_rvs(logp_combined)
116+
117+
logp_fn = pytensor.function([x_vv, y_vv], logp_combined)
115118
assert np.isclose(
116-
logp_fn(np.ones(5), np.arange(5) + 1),
119+
logp_fn(np.ones(5), np.arange(5) + 1).sum(),
117120
st.norm(1, 1).logpdf(1) * 10,
118121
)

tests/logprob/test_mixture.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
from pymc.logprob.rewriting import construct_ir_fgraph
5353
from pymc.logprob.utils import dirac_delta
5454
from pymc.testing import assert_no_rvs
55-
from tests.logprob.utils import joint_logprob, scipy_logprob
55+
from tests.logprob.utils import scipy_logprob
5656

5757

5858
def test_mixture_basics():
@@ -101,7 +101,7 @@ def create_mix_model(size, axis):
101101
i_vv = env["i_vv"]
102102
M_rv = env["M_rv"]
103103
m_vv = env["m_vv"]
104-
joint_logprob({M_rv: m_vv, I_rv: i_vv})
104+
factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv})
105105

106106

107107
@pytensor.config.change_flags(compute_test_value="warn")
@@ -134,9 +134,10 @@ def test_compute_test_value(op_constructor):
134134

135135
del M_rv.tag.test_value
136136

137-
M_logp = joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
137+
M_logp = factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv})
138+
M_logp_combined = pt.add(*M_logp.values())
138139

139-
assert isinstance(M_logp.tag.test_value, np.ndarray)
140+
assert isinstance(M_logp_combined.tag.test_value, np.ndarray)
140141

141142

142143
@pytest.mark.parametrize(
@@ -179,13 +180,14 @@ def test_hetero_mixture_binomial(p_val, size, supported):
179180
m_vv.name = "m"
180181

181182
if supported:
182-
M_logp = joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
183+
M_logp = factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv})
184+
M_logp_combined = pt.add(*M_logp.values())
183185
else:
184186
with pytest.raises(RuntimeError, match="could not be derived: {m}"):
185-
joint_logprob({M_rv: m_vv, I_rv: i_vv}, sum=False)
187+
factorized_joint_logprob({M_rv: m_vv, I_rv: i_vv})
186188
return
187189

188-
M_logp_fn = pytensor.function([p_at, m_vv, i_vv], M_logp)
190+
M_logp_fn = pytensor.function([p_at, m_vv, i_vv], M_logp_combined)
189191

190192
assert_no_rvs(M_logp_fn.maker.fgraph.outputs[0])
191193

@@ -936,14 +938,16 @@ def test_switch_mixture():
936938

937939
assert equal_computations(fgraph.outputs, fgraph2.outputs)
938940

939-
z1_logp = joint_logprob({Z1_rv: z_vv, I_rv: i_vv})
940-
z2_logp = joint_logprob({Z2_rv: z_vv, I_rv: i_vv})
941+
z1_logp = factorized_joint_logprob({Z1_rv: z_vv, I_rv: i_vv})
942+
z2_logp = factorized_joint_logprob({Z2_rv: z_vv, I_rv: i_vv})
943+
z1_logp_combined = pt.sum([pt.sum(factor) for factor in z1_logp.values()])
944+
z2_logp_combined = pt.sum([pt.sum(factor) for factor in z2_logp.values()])
941945

942946
# below should follow immediately from the equal_computations assertion above
943-
assert equal_computations([z1_logp], [z2_logp])
947+
assert equal_computations([z1_logp_combined], [z2_logp_combined])
944948

945-
np.testing.assert_almost_equal(0.69049938, z1_logp.eval({z_vv: -10, i_vv: 0}))
946-
np.testing.assert_almost_equal(0.69049938, z2_logp.eval({z_vv: -10, i_vv: 0}))
949+
np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 0}))
950+
np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 0}))
947951

948952

949953
def test_ifelse_mixture_one_component():

0 commit comments

Comments
 (0)