Commit 1b8bbfd

Solve conflicts in test_basic.py
1 parent ace9892 · commit 1b8bbfd
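This merge resolution applies one API change throughout the file: the tests call `factorized_joint_logprob`, which returns a dict mapping each value variable to its log-probability factor, in place of the removed `tests.logprob.utils.joint_logprob` wrapper, so every call site now combines the factors itself. A minimal sketch of that pattern, as applied in the diff below (the import path and variable names are assumptions for illustration, not taken from the test file):

import pytensor.tensor as pt

from pymc.logprob import factorized_joint_logprob  # import path assumed; not shown in this diff

# Build a random variable and a matching value variable, as the tests do.
a = pt.random.uniform(0.0, 1.0)
a_value = a.clone()

# factorized_joint_logprob returns {value_var: logprob_factor} ...
factors = factorized_joint_logprob({a: a_value})

# ... and the caller reduces the factors, e.g. to a single scalar log-probability.
logp_combined = pt.sum([pt.sum(factor) for factor in factors.values()])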

File tree

1 file changed: +37 -82 lines changed


tests/logprob/test_basic.py

Lines changed: 37 additions & 82 deletions
@@ -60,19 +60,19 @@
 from pymc.logprob.utils import rvs_to_value_vars, walk_model
 from pymc.pytensorf import replace_rvs_by_values
 from pymc.testing import assert_no_rvs
-from tests.logprob.utils import joint_logprob


-def test_joint_logprob_basic():
-    # A simple check for when `joint_logprob` is the same as `logprob`
+def test_factorized_joint_logprob_basic():
+    # A simple check for when `factorized_joint_logprob` is the same as `logprob`
     a = pt.random.uniform(0.0, 1.0)
     a.name = "a"
     a_value_var = a.clone()

-    a_logp = joint_logprob({a: a_value_var}, sum=False)
+    a_logp = factorized_joint_logprob({a: a_value_var})
+    a_logp_comb = tuple(a_logp.values())[0]
     a_logp_exp = logp(a, a_value_var)

-    assert equal_computations([a_logp], [a_logp_exp])
+    assert equal_computations([a_logp_comb], [a_logp_exp])

     # Let's try a hierarchical model
     sigma = pt.random.invgamma(0.5, 0.5)
@@ -81,7 +81,8 @@ def test_joint_logprob_basic():
     sigma_value_var = sigma.clone()
     y_value_var = Y.clone()

-    total_ll = joint_logprob({Y: y_value_var, sigma: sigma_value_var}, sum=False)
+    total_ll = factorized_joint_logprob({Y: y_value_var, sigma: sigma_value_var})
+    total_ll_combined = pt.add(*total_ll.values())

     # We need to replace the reference to `sigma` in `Y` with its value
     # variable
@@ -92,7 +93,7 @@ def test_joint_logprob_basic():
     )
     total_ll_exp = logp(sigma, sigma_value_var) + ll_Y

-    assert equal_computations([total_ll], [total_ll_exp])
+    assert equal_computations([total_ll_combined], [total_ll_exp])

     # Now, make sure we can compute a joint log-probability for a hierarchical
     # model with some non-`RandomVariable` nodes
@@ -105,42 +106,46 @@ def test_joint_logprob_basic():
     b_value_var = b.clone()
     c_value_var = c.clone()

-    b_logp = joint_logprob({a: a_value_var, b: b_value_var, c: c_value_var})
+    b_logp = factorized_joint_logprob({a: a_value_var, b: b_value_var, c: c_value_var})
+    b_logp_combined = pt.sum([pt.sum(factor) for factor in b_logp.values()])

     # There shouldn't be any `RandomVariable`s in the resulting graph
-    assert_no_rvs(b_logp)
+    assert_no_rvs(b_logp_combined)

-    res_ancestors = list(walk_model((b_logp,), walk_past_rvs=True))
+    res_ancestors = list(walk_model((b_logp_combined,), walk_past_rvs=True))
     assert b_value_var in res_ancestors
     assert c_value_var in res_ancestors
     assert a_value_var in res_ancestors


-def test_joint_logprob_multi_obs():
+def test_factorized_joint_logprob_multi_obs():
     a = pt.random.uniform(0.0, 1.0)
     b = pt.random.normal(0.0, 1.0)

     a_val = a.clone()
     b_val = b.clone()

-    logp_res = joint_logprob({a: a_val, b: b_val}, sum=False)
+    logp_res = factorized_joint_logprob({a: a_val, b: b_val})
+    logp_res_combined = pt.add(*logp_res.values())
     logp_exp = logp(a, a_val) + logp(b, b_val)

-    assert equal_computations([logp_res], [logp_exp])
+    assert equal_computations([logp_res_combined], [logp_exp])

     x = pt.random.normal(0, 1)
     y = pt.random.normal(x, 1)

     x_val = x.clone()
     y_val = y.clone()

-    logp_res = joint_logprob({x: x_val, y: y_val})
-    exp_logp = joint_logprob({x: x_val, y: y_val})
+    logp_res = factorized_joint_logprob({x: x_val, y: y_val})
+    exp_logp = factorized_joint_logprob({x: x_val, y: y_val})
+    logp_res_comb = pt.sum([pt.sum(factor) for factor in logp_res.values()])
+    exp_logp_comb = pt.sum([pt.sum(factor) for factor in exp_logp.values()])

-    assert equal_computations([logp_res], [exp_logp])
+    assert equal_computations([logp_res_comb], [exp_logp_comb])


-def test_joint_logprob_diff_dims():
+def test_factorized_joint_logprob_diff_dims():
     M = pt.matrix("M")
     x = pt.random.normal(0, 1, size=M.shape[1], name="X")
     y = pt.random.normal(M.dot(x), 1, name="Y")
@@ -150,14 +155,15 @@ def test_joint_logprob_diff_dims():
     y_vv = y.clone()
     y_vv.name = "y"

-    logp = joint_logprob({x: x_vv, y: y_vv})
+    logp = factorized_joint_logprob({x: x_vv, y: y_vv})
+    logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])

     M_val = np.random.normal(size=(10, 3))
     x_val = np.random.normal(size=(3,))
     y_val = np.random.normal(size=(10,))

     point = {M: M_val, x_vv: x_val, y_vv: y_val}
-    logp_val = logp.eval(point)
+    logp_val = logp_combined.eval(point)

     exp_logp_val = (
         sp.norm.logpdf(x_val, 0, 1).sum() + sp.norm.logpdf(y_val, M_val.dot(x_val), 1).sum()
@@ -179,60 +185,6 @@ def test_incsubtensor_original_values_output_dict():
     assert vv in logp_dict


-def test_joint_logprob_subtensor():
-    """Make sure we can compute a joint log-probability for ``Y[I]`` where ``Y`` and ``I`` are random variables."""
-
-    size = 5
-
-    mu_base = np.power(10, np.arange(np.prod(size))).reshape(size)
-    mu = np.stack([mu_base, -mu_base])
-    sigma = 0.001
-    rng = pytensor.shared(np.random.RandomState(232), borrow=True)
-
-    A_rv = pt.random.normal(mu, sigma, rng=rng)
-    A_rv.name = "A"
-
-    p = 0.5
-
-    I_rv = pt.random.bernoulli(p, size=size, rng=rng)
-    I_rv.name = "I"
-
-    A_idx = A_rv[I_rv, pt.ogrid[A_rv.shape[-1] :]]
-
-    assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1))
-
-    A_idx_value_var = A_idx.type()
-    A_idx_value_var.name = "A_idx_value"
-
-    I_value_var = I_rv.type()
-    I_value_var.name = "I_value"
-
-    A_idx_logp = joint_logprob({A_idx: A_idx_value_var, I_rv: I_value_var}, sum=False)
-
-    logp_vals_fn = pytensor.function([A_idx_value_var, I_value_var], A_idx_logp)
-
-    # The compiled graph should not contain any `RandomVariables`
-    assert_no_rvs(logp_vals_fn.maker.fgraph.outputs[0])
-
-    decimals = 6 if pytensor.config.floatX == "float64" else 4
-
-    test_val_rng = np.random.RandomState(3238)
-
-    for i in range(10):
-        bern_sp = sp.bernoulli(p)
-        I_value = bern_sp.rvs(size=size, random_state=test_val_rng).astype(I_rv.dtype)
-
-        norm_sp = sp.norm(mu[I_value, np.ogrid[mu.shape[1] :]], sigma)
-        A_idx_value = norm_sp.rvs(random_state=test_val_rng).astype(A_idx.dtype)
-
-        exp_obs_logps = norm_sp.logpdf(A_idx_value)
-        exp_obs_logps += bern_sp.logpmf(I_value)
-
-        logp_vals = logp_vals_fn(A_idx_value, I_value)
-
-        np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals)
-
-
 def test_persist_inputs():
     """Make sure we don't unnecessarily clone variables."""
     x = pt.scalar("x")
@@ -242,24 +194,27 @@ def test_persist_inputs():
     beta_vv = beta_rv.type()
     y_vv = Y_rv.clone()

-    logp = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv})
+    logp = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv})
+    logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])

-    assert x in ancestors([logp])
+    assert x in ancestors([logp_combined])

     # Make sure we don't clone value variables when they're graphs.
     y_vv_2 = y_vv * 2
-    logp_2 = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
+    logp_2 = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
+    logp_2_combined = pt.sum([pt.sum(factor) for factor in logp_2.values()])

-    assert y_vv in ancestors([logp_2])
-    assert y_vv_2 in ancestors([logp_2])
+    assert y_vv in ancestors([logp_2_combined])
+    assert y_vv_2 in ancestors([logp_2_combined])

     # Even when they are random
     y_vv = pt.random.normal(name="y_vv2")
     y_vv_2 = y_vv * 2
-    logp_2 = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
+    logp_2 = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
+    logp_2_combined = pt.sum([pt.sum(factor) for factor in logp_2.values()])

-    assert y_vv in ancestors([logp_2])
-    assert y_vv_2 in ancestors([logp_2])
+    assert y_vv in ancestors([logp_2_combined])
+    assert y_vv_2 in ancestors([logp_2_combined])


 def test_warn_random_found_factorized_joint_logprob():
@@ -284,7 +239,7 @@ def test_multiple_rvs_to_same_value_raises():

     msg = "More than one logprob factor was assigned to the value var x"
     with pytest.raises(ValueError, match=msg):
-        joint_logprob({x_rv1: x, x_rv2: x})
+        factorized_joint_logprob({x_rv1: x, x_rv2: x})


 def test_joint_logp_basic():
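Where a single summed log-probability is still convenient (what the removed `joint_logprob` helper used to provide via its `sum` argument), the combining step can be wrapped once. A minimal sketch of such a wrapper, not the removed helper's actual implementation, again assuming `factorized_joint_logprob` is importable from `pymc.logprob`:

import pytensor.tensor as pt

from pymc.logprob import factorized_joint_logprob  # import path assumed


def summed_joint_logprob(rv_values):
    # Hypothetical convenience wrapper: sum every factor returned by
    # factorized_joint_logprob into one scalar log-probability graph.
    factors = factorized_joint_logprob(rv_values)
    return pt.sum([pt.sum(factor) for factor in factors.values()])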
