Skip to content

Commit 0073639

Browse files
authored
Remove joint_logprob function from tests.logprob.utils (#6650)
* Remove joint_logprob function from tests.logprob.utils * Move the joint logprob test for subtensors to test_mixture.py
1 parent c57769c commit 0073639

File tree

10 files changed

+292
-256
lines changed

10 files changed

+292
-256
lines changed

tests/logprob/test_basic.py

Lines changed: 37 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,19 @@
6060
from pymc.logprob.utils import rvs_to_value_vars, walk_model
6161
from pymc.pytensorf import replace_rvs_by_values
6262
from pymc.testing import assert_no_rvs
63-
from tests.logprob.utils import joint_logprob
6463

6564

66-
def test_joint_logprob_basic():
67-
# A simple check for when `joint_logprob` is the same as `logprob`
65+
def test_factorized_joint_logprob_basic():
66+
# A simple check for when `factorized_joint_logprob` is the same as `logprob`
6867
a = pt.random.uniform(0.0, 1.0)
6968
a.name = "a"
7069
a_value_var = a.clone()
7170

72-
a_logp = joint_logprob({a: a_value_var}, sum=False)
71+
a_logp = factorized_joint_logprob({a: a_value_var})
72+
a_logp_comb = tuple(a_logp.values())[0]
7373
a_logp_exp = logp(a, a_value_var)
7474

75-
assert equal_computations([a_logp], [a_logp_exp])
75+
assert equal_computations([a_logp_comb], [a_logp_exp])
7676

7777
# Let's try a hierarchical model
7878
sigma = pt.random.invgamma(0.5, 0.5)
@@ -81,7 +81,8 @@ def test_joint_logprob_basic():
8181
sigma_value_var = sigma.clone()
8282
y_value_var = Y.clone()
8383

84-
total_ll = joint_logprob({Y: y_value_var, sigma: sigma_value_var}, sum=False)
84+
total_ll = factorized_joint_logprob({Y: y_value_var, sigma: sigma_value_var})
85+
total_ll_combined = pt.add(*total_ll.values())
8586

8687
# We need to replace the reference to `sigma` in `Y` with its value
8788
# variable
@@ -92,7 +93,7 @@ def test_joint_logprob_basic():
9293
)
9394
total_ll_exp = logp(sigma, sigma_value_var) + ll_Y
9495

95-
assert equal_computations([total_ll], [total_ll_exp])
96+
assert equal_computations([total_ll_combined], [total_ll_exp])
9697

9798
# Now, make sure we can compute a joint log-probability for a hierarchical
9899
# model with some non-`RandomVariable` nodes
@@ -105,42 +106,46 @@ def test_joint_logprob_basic():
105106
b_value_var = b.clone()
106107
c_value_var = c.clone()
107108

108-
b_logp = joint_logprob({a: a_value_var, b: b_value_var, c: c_value_var})
109+
b_logp = factorized_joint_logprob({a: a_value_var, b: b_value_var, c: c_value_var})
110+
b_logp_combined = pt.sum([pt.sum(factor) for factor in b_logp.values()])
109111

110112
# There shouldn't be any `RandomVariable`s in the resulting graph
111-
assert_no_rvs(b_logp)
113+
assert_no_rvs(b_logp_combined)
112114

113-
res_ancestors = list(walk_model((b_logp,), walk_past_rvs=True))
115+
res_ancestors = list(walk_model((b_logp_combined,), walk_past_rvs=True))
114116
assert b_value_var in res_ancestors
115117
assert c_value_var in res_ancestors
116118
assert a_value_var in res_ancestors
117119

118120

119-
def test_joint_logprob_multi_obs():
121+
def test_factorized_joint_logprob_multi_obs():
120122
a = pt.random.uniform(0.0, 1.0)
121123
b = pt.random.normal(0.0, 1.0)
122124

123125
a_val = a.clone()
124126
b_val = b.clone()
125127

126-
logp_res = joint_logprob({a: a_val, b: b_val}, sum=False)
128+
logp_res = factorized_joint_logprob({a: a_val, b: b_val})
129+
logp_res_combined = pt.add(*logp_res.values())
127130
logp_exp = logp(a, a_val) + logp(b, b_val)
128131

129-
assert equal_computations([logp_res], [logp_exp])
132+
assert equal_computations([logp_res_combined], [logp_exp])
130133

131134
x = pt.random.normal(0, 1)
132135
y = pt.random.normal(x, 1)
133136

134137
x_val = x.clone()
135138
y_val = y.clone()
136139

137-
logp_res = joint_logprob({x: x_val, y: y_val})
138-
exp_logp = joint_logprob({x: x_val, y: y_val})
140+
logp_res = factorized_joint_logprob({x: x_val, y: y_val})
141+
exp_logp = factorized_joint_logprob({x: x_val, y: y_val})
142+
logp_res_comb = pt.sum([pt.sum(factor) for factor in logp_res.values()])
143+
exp_logp_comb = pt.sum([pt.sum(factor) for factor in exp_logp.values()])
139144

140-
assert equal_computations([logp_res], [exp_logp])
145+
assert equal_computations([logp_res_comb], [exp_logp_comb])
141146

142147

143-
def test_joint_logprob_diff_dims():
148+
def test_factorized_joint_logprob_diff_dims():
144149
M = pt.matrix("M")
145150
x = pt.random.normal(0, 1, size=M.shape[1], name="X")
146151
y = pt.random.normal(M.dot(x), 1, name="Y")
@@ -150,14 +155,15 @@ def test_joint_logprob_diff_dims():
150155
y_vv = y.clone()
151156
y_vv.name = "y"
152157

153-
logp = joint_logprob({x: x_vv, y: y_vv})
158+
logp = factorized_joint_logprob({x: x_vv, y: y_vv})
159+
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
154160

155161
M_val = np.random.normal(size=(10, 3))
156162
x_val = np.random.normal(size=(3,))
157163
y_val = np.random.normal(size=(10,))
158164

159165
point = {M: M_val, x_vv: x_val, y_vv: y_val}
160-
logp_val = logp.eval(point)
166+
logp_val = logp_combined.eval(point)
161167

162168
exp_logp_val = (
163169
sp.norm.logpdf(x_val, 0, 1).sum() + sp.norm.logpdf(y_val, M_val.dot(x_val), 1).sum()
@@ -179,60 +185,6 @@ def test_incsubtensor_original_values_output_dict():
179185
assert vv in logp_dict
180186

181187

182-
def test_joint_logprob_subtensor():
183-
"""Make sure we can compute a joint log-probability for ``Y[I]`` where ``Y`` and ``I`` are random variables."""
184-
185-
size = 5
186-
187-
mu_base = np.power(10, np.arange(np.prod(size))).reshape(size)
188-
mu = np.stack([mu_base, -mu_base])
189-
sigma = 0.001
190-
rng = pytensor.shared(np.random.RandomState(232), borrow=True)
191-
192-
A_rv = pt.random.normal(mu, sigma, rng=rng)
193-
A_rv.name = "A"
194-
195-
p = 0.5
196-
197-
I_rv = pt.random.bernoulli(p, size=size, rng=rng)
198-
I_rv.name = "I"
199-
200-
A_idx = A_rv[I_rv, pt.ogrid[A_rv.shape[-1] :]]
201-
202-
assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1))
203-
204-
A_idx_value_var = A_idx.type()
205-
A_idx_value_var.name = "A_idx_value"
206-
207-
I_value_var = I_rv.type()
208-
I_value_var.name = "I_value"
209-
210-
A_idx_logp = joint_logprob({A_idx: A_idx_value_var, I_rv: I_value_var}, sum=False)
211-
212-
logp_vals_fn = pytensor.function([A_idx_value_var, I_value_var], A_idx_logp)
213-
214-
# The compiled graph should not contain any `RandomVariables`
215-
assert_no_rvs(logp_vals_fn.maker.fgraph.outputs[0])
216-
217-
decimals = 6 if pytensor.config.floatX == "float64" else 4
218-
219-
test_val_rng = np.random.RandomState(3238)
220-
221-
for i in range(10):
222-
bern_sp = sp.bernoulli(p)
223-
I_value = bern_sp.rvs(size=size, random_state=test_val_rng).astype(I_rv.dtype)
224-
225-
norm_sp = sp.norm(mu[I_value, np.ogrid[mu.shape[1] :]], sigma)
226-
A_idx_value = norm_sp.rvs(random_state=test_val_rng).astype(A_idx.dtype)
227-
228-
exp_obs_logps = norm_sp.logpdf(A_idx_value)
229-
exp_obs_logps += bern_sp.logpmf(I_value)
230-
231-
logp_vals = logp_vals_fn(A_idx_value, I_value)
232-
233-
np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals)
234-
235-
236188
def test_persist_inputs():
237189
"""Make sure we don't unnecessarily clone variables."""
238190
x = pt.scalar("x")
@@ -242,24 +194,27 @@ def test_persist_inputs():
242194
beta_vv = beta_rv.type()
243195
y_vv = Y_rv.clone()
244196

245-
logp = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv})
197+
logp = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv})
198+
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
246199

247-
assert x in ancestors([logp])
200+
assert x in ancestors([logp_combined])
248201

249202
# Make sure we don't clone value variables when they're graphs.
250203
y_vv_2 = y_vv * 2
251-
logp_2 = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
204+
logp_2 = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
205+
logp_2_combined = pt.sum([pt.sum(factor) for factor in logp_2.values()])
252206

253-
assert y_vv in ancestors([logp_2])
254-
assert y_vv_2 in ancestors([logp_2])
207+
assert y_vv in ancestors([logp_2_combined])
208+
assert y_vv_2 in ancestors([logp_2_combined])
255209

256210
# Even when they are random
257211
y_vv = pt.random.normal(name="y_vv2")
258212
y_vv_2 = y_vv * 2
259-
logp_2 = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
213+
logp_2 = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
214+
logp_2_combined = pt.sum([pt.sum(factor) for factor in logp_2.values()])
260215

261-
assert y_vv in ancestors([logp_2])
262-
assert y_vv_2 in ancestors([logp_2])
216+
assert y_vv in ancestors([logp_2_combined])
217+
assert y_vv_2 in ancestors([logp_2_combined])
263218

264219

265220
def test_warn_random_found_factorized_joint_logprob():
@@ -284,7 +239,7 @@ def test_multiple_rvs_to_same_value_raises():
284239

285240
msg = "More than one logprob factor was assigned to the value var x"
286241
with pytest.raises(ValueError, match=msg):
287-
joint_logprob({x_rv1: x, x_rv2: x})
242+
factorized_joint_logprob({x_rv1: x, x_rv2: x})
288243

289244

290245
def test_joint_logp_basic():

tests/logprob/test_censoring.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@
4141
import scipy as sp
4242
import scipy.stats as st
4343

44+
from pymc import logp
4445
from pymc.logprob import factorized_joint_logprob
4546
from pymc.logprob.transforms import LogTransform, TransformValuesRewrite
4647
from pymc.testing import assert_no_rvs
47-
from tests.logprob.utils import joint_logprob
4848

4949

5050
@pytensor.config.change_flags(compute_test_value="raise")
@@ -55,10 +55,10 @@ def test_continuous_rv_clip():
5555
cens_x_vv = cens_x_rv.clone()
5656
cens_x_vv.tag.test_value = 0
5757

58-
logp = joint_logprob({cens_x_rv: cens_x_vv})
59-
assert_no_rvs(logp)
58+
logprob = pt.sum(logp(cens_x_rv, cens_x_vv))
59+
assert_no_rvs(logprob)
6060

61-
logp_fn = pytensor.function([cens_x_vv], logp)
61+
logp_fn = pytensor.function([cens_x_vv], logprob)
6262
ref_scipy = st.norm(0.5, 1)
6363

6464
assert logp_fn(-3) == -np.inf
@@ -75,10 +75,10 @@ def test_discrete_rv_clip():
7575

7676
cens_x_vv = cens_x_rv.clone()
7777

78-
logp = joint_logprob({cens_x_rv: cens_x_vv})
79-
assert_no_rvs(logp)
78+
logprob = pt.sum(logp(cens_x_rv, cens_x_vv))
79+
assert_no_rvs(logprob)
8080

81-
logp_fn = pytensor.function([cens_x_vv], logp)
81+
logp_fn = pytensor.function([cens_x_vv], logprob)
8282
ref_scipy = st.poisson(2)
8383

8484
assert logp_fn(0) == -np.inf
@@ -97,8 +97,8 @@ def test_one_sided_clip():
9797
lb_cens_x_vv = lb_cens_x_rv.clone()
9898
ub_cens_x_vv = ub_cens_x_rv.clone()
9999

100-
lb_logp = joint_logprob({lb_cens_x_rv: lb_cens_x_vv})
101-
ub_logp = joint_logprob({ub_cens_x_rv: ub_cens_x_vv})
100+
lb_logp = pt.sum(logp(lb_cens_x_rv, lb_cens_x_vv))
101+
ub_logp = pt.sum(logp(ub_cens_x_rv, ub_cens_x_vv))
102102
assert_no_rvs(lb_logp)
103103
assert_no_rvs(ub_logp)
104104

@@ -117,10 +117,10 @@ def test_useless_clip():
117117

118118
cens_x_vv = cens_x_rv.clone()
119119

120-
logp = joint_logprob({cens_x_rv: cens_x_vv}, sum=False)
121-
assert_no_rvs(logp)
120+
logprob = logp(cens_x_rv, cens_x_vv)
121+
assert_no_rvs(logprob)
122122

123-
logp_fn = pytensor.function([cens_x_vv], logp)
123+
logp_fn = pytensor.function([cens_x_vv], logprob)
124124
ref_scipy = st.norm(0.5, 1)
125125

126126
np.testing.assert_allclose(logp_fn([-2, 0, 2]), ref_scipy.logpdf([-2, 0, 2]))
@@ -133,10 +133,12 @@ def test_random_clip():
133133

134134
lb_vv = lb_rv.clone()
135135
cens_x_vv = cens_x_rv.clone()
136-
logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv}, sum=False)
137-
assert_no_rvs(logp)
136+
logp = factorized_joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
137+
logp_combined = pt.add(*logp.values())
138138

139-
logp_fn = pytensor.function([lb_vv, cens_x_vv], logp)
139+
assert_no_rvs(logp_combined)
140+
141+
logp_fn = pytensor.function([lb_vv, cens_x_vv], logp_combined)
140142
res = logp_fn([0, -1], [-1, -1])
141143
assert res[0] == -np.inf
142144
assert res[1] != -np.inf
@@ -150,8 +152,10 @@ def test_broadcasted_clip_constant():
150152
lb_vv = lb_rv.clone()
151153
cens_x_vv = cens_x_rv.clone()
152154

153-
logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
154-
assert_no_rvs(logp)
155+
logp = factorized_joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
156+
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
157+
158+
assert_no_rvs(logp_combined)
155159

156160

157161
def test_broadcasted_clip_random():
@@ -162,8 +166,10 @@ def test_broadcasted_clip_random():
162166
lb_vv = lb_rv.clone()
163167
cens_x_vv = cens_x_rv.clone()
164168

165-
logp = joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
166-
assert_no_rvs(logp)
169+
logp = factorized_joint_logprob({cens_x_rv: cens_x_vv, lb_rv: lb_vv})
170+
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
171+
172+
assert_no_rvs(logp_combined)
167173

168174

169175
def test_fail_base_and_clip_have_values():
@@ -199,10 +205,11 @@ def test_deterministic_clipping():
199205

200206
x_vv = x_rv.clone()
201207
y_vv = y_rv.clone()
202-
logp = joint_logprob({x_rv: x_vv, y_rv: y_vv})
203-
assert_no_rvs(logp)
208+
logp = factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
209+
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
210+
assert_no_rvs(logp_combined)
204211

205-
logp_fn = pytensor.function([x_vv, y_vv], logp)
212+
logp_fn = pytensor.function([x_vv, y_vv], logp_combined)
206213
assert np.isclose(
207214
logp_fn(-1, 1),
208215
st.norm(0, 1).logpdf(-1) + st.norm(0, 1).logpdf(1),
@@ -216,10 +223,11 @@ def test_clip_transform():
216223
cens_x_vv = cens_x_rv.clone()
217224

218225
transform = TransformValuesRewrite({cens_x_vv: LogTransform()})
219-
logp = joint_logprob({cens_x_rv: cens_x_vv}, extra_rewrites=transform)
226+
logp = factorized_joint_logprob({cens_x_rv: cens_x_vv}, extra_rewrites=transform)
227+
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
220228

221229
cens_x_vv_testval = -1
222-
obs_logp = logp.eval({cens_x_vv: cens_x_vv_testval})
230+
obs_logp = logp_combined.eval({cens_x_vv: cens_x_vv_testval})
223231
exp_logp = sp.stats.norm(0.5, 1).logpdf(np.exp(cens_x_vv_testval)) + cens_x_vv_testval
224232

225233
assert np.isclose(obs_logp, exp_logp)
@@ -236,8 +244,8 @@ def test_rounding(rounding_op):
236244
xr.name = "xr"
237245

238246
xr_vv = xr.clone()
239-
logp = joint_logprob({xr: xr_vv}, sum=False)
240-
assert logp is not None
247+
logprob = logp(xr, xr_vv)
248+
assert logprob is not None
241249

242250
x_sp = st.norm(loc, scale)
243251
if rounding_op == pt.round:
@@ -250,6 +258,6 @@ def test_rounding(rounding_op):
250258
raise NotImplementedError()
251259

252260
assert np.allclose(
253-
logp.eval({xr_vv: test_value}),
261+
logprob.eval({xr_vv: test_value}),
254262
expected_logp,
255263
)

0 commit comments

Comments
 (0)