60
60
from pymc .logprob .utils import rvs_to_value_vars , walk_model
61
61
from pymc .pytensorf import replace_rvs_by_values
62
62
from pymc .testing import assert_no_rvs
63
- from tests .logprob .utils import joint_logprob
64
63
65
64
66
- def test_joint_logprob_basic ():
67
- # A simple check for when `joint_logprob ` is the same as `logprob`
65
+ def test_factorized_joint_logprob_basic ():
66
+ # A simple check for when `factorized_joint_logprob ` is the same as `logprob`
68
67
a = pt .random .uniform (0.0 , 1.0 )
69
68
a .name = "a"
70
69
a_value_var = a .clone ()
71
70
72
- a_logp = joint_logprob ({a : a_value_var }, sum = False )
71
+ a_logp = factorized_joint_logprob ({a : a_value_var })
72
+ a_logp_comb = tuple (a_logp .values ())[0 ]
73
73
a_logp_exp = logp (a , a_value_var )
74
74
75
- assert equal_computations ([a_logp ], [a_logp_exp ])
75
+ assert equal_computations ([a_logp_comb ], [a_logp_exp ])
76
76
77
77
# Let's try a hierarchical model
78
78
sigma = pt .random .invgamma (0.5 , 0.5 )
@@ -81,7 +81,8 @@ def test_joint_logprob_basic():
81
81
sigma_value_var = sigma .clone ()
82
82
y_value_var = Y .clone ()
83
83
84
- total_ll = joint_logprob ({Y : y_value_var , sigma : sigma_value_var }, sum = False )
84
+ total_ll = factorized_joint_logprob ({Y : y_value_var , sigma : sigma_value_var })
85
+ total_ll_combined = pt .add (* total_ll .values ())
85
86
86
87
# We need to replace the reference to `sigma` in `Y` with its value
87
88
# variable
@@ -92,7 +93,7 @@ def test_joint_logprob_basic():
92
93
)
93
94
total_ll_exp = logp (sigma , sigma_value_var ) + ll_Y
94
95
95
- assert equal_computations ([total_ll ], [total_ll_exp ])
96
+ assert equal_computations ([total_ll_combined ], [total_ll_exp ])
96
97
97
98
# Now, make sure we can compute a joint log-probability for a hierarchical
98
99
# model with some non-`RandomVariable` nodes
@@ -105,42 +106,46 @@ def test_joint_logprob_basic():
105
106
b_value_var = b .clone ()
106
107
c_value_var = c .clone ()
107
108
108
- b_logp = joint_logprob ({a : a_value_var , b : b_value_var , c : c_value_var })
109
+ b_logp = factorized_joint_logprob ({a : a_value_var , b : b_value_var , c : c_value_var })
110
+ b_logp_combined = pt .sum ([pt .sum (factor ) for factor in b_logp .values ()])
109
111
110
112
# There shouldn't be any `RandomVariable`s in the resulting graph
111
- assert_no_rvs (b_logp )
113
+ assert_no_rvs (b_logp_combined )
112
114
113
- res_ancestors = list (walk_model ((b_logp ,), walk_past_rvs = True ))
115
+ res_ancestors = list (walk_model ((b_logp_combined ,), walk_past_rvs = True ))
114
116
assert b_value_var in res_ancestors
115
117
assert c_value_var in res_ancestors
116
118
assert a_value_var in res_ancestors
117
119
118
120
119
- def test_joint_logprob_multi_obs ():
121
+ def test_factorized_joint_logprob_multi_obs ():
120
122
a = pt .random .uniform (0.0 , 1.0 )
121
123
b = pt .random .normal (0.0 , 1.0 )
122
124
123
125
a_val = a .clone ()
124
126
b_val = b .clone ()
125
127
126
- logp_res = joint_logprob ({a : a_val , b : b_val }, sum = False )
128
+ logp_res = factorized_joint_logprob ({a : a_val , b : b_val })
129
+ logp_res_combined = pt .add (* logp_res .values ())
127
130
logp_exp = logp (a , a_val ) + logp (b , b_val )
128
131
129
- assert equal_computations ([logp_res ], [logp_exp ])
132
+ assert equal_computations ([logp_res_combined ], [logp_exp ])
130
133
131
134
x = pt .random .normal (0 , 1 )
132
135
y = pt .random .normal (x , 1 )
133
136
134
137
x_val = x .clone ()
135
138
y_val = y .clone ()
136
139
137
- logp_res = joint_logprob ({x : x_val , y : y_val })
138
- exp_logp = joint_logprob ({x : x_val , y : y_val })
140
+ logp_res = factorized_joint_logprob ({x : x_val , y : y_val })
141
+ exp_logp = factorized_joint_logprob ({x : x_val , y : y_val })
142
+ logp_res_comb = pt .sum ([pt .sum (factor ) for factor in logp_res .values ()])
143
+ exp_logp_comb = pt .sum ([pt .sum (factor ) for factor in exp_logp .values ()])
139
144
140
- assert equal_computations ([logp_res ], [exp_logp ])
145
+ assert equal_computations ([logp_res_comb ], [exp_logp_comb ])
141
146
142
147
143
- def test_joint_logprob_diff_dims ():
148
+ def test_factorized_joint_logprob_diff_dims ():
144
149
M = pt .matrix ("M" )
145
150
x = pt .random .normal (0 , 1 , size = M .shape [1 ], name = "X" )
146
151
y = pt .random .normal (M .dot (x ), 1 , name = "Y" )
@@ -150,14 +155,15 @@ def test_joint_logprob_diff_dims():
150
155
y_vv = y .clone ()
151
156
y_vv .name = "y"
152
157
153
- logp = joint_logprob ({x : x_vv , y : y_vv })
158
+ logp = factorized_joint_logprob ({x : x_vv , y : y_vv })
159
+ logp_combined = pt .sum ([pt .sum (factor ) for factor in logp .values ()])
154
160
155
161
M_val = np .random .normal (size = (10 , 3 ))
156
162
x_val = np .random .normal (size = (3 ,))
157
163
y_val = np .random .normal (size = (10 ,))
158
164
159
165
point = {M : M_val , x_vv : x_val , y_vv : y_val }
160
- logp_val = logp .eval (point )
166
+ logp_val = logp_combined .eval (point )
161
167
162
168
exp_logp_val = (
163
169
sp .norm .logpdf (x_val , 0 , 1 ).sum () + sp .norm .logpdf (y_val , M_val .dot (x_val ), 1 ).sum ()
@@ -179,60 +185,6 @@ def test_incsubtensor_original_values_output_dict():
179
185
assert vv in logp_dict
180
186
181
187
182
- def test_joint_logprob_subtensor ():
183
- """Make sure we can compute a joint log-probability for ``Y[I]`` where ``Y`` and ``I`` are random variables."""
184
-
185
- size = 5
186
-
187
- mu_base = np .power (10 , np .arange (np .prod (size ))).reshape (size )
188
- mu = np .stack ([mu_base , - mu_base ])
189
- sigma = 0.001
190
- rng = pytensor .shared (np .random .RandomState (232 ), borrow = True )
191
-
192
- A_rv = pt .random .normal (mu , sigma , rng = rng )
193
- A_rv .name = "A"
194
-
195
- p = 0.5
196
-
197
- I_rv = pt .random .bernoulli (p , size = size , rng = rng )
198
- I_rv .name = "I"
199
-
200
- A_idx = A_rv [I_rv , pt .ogrid [A_rv .shape [- 1 ] :]]
201
-
202
- assert isinstance (A_idx .owner .op , (Subtensor , AdvancedSubtensor , AdvancedSubtensor1 ))
203
-
204
- A_idx_value_var = A_idx .type ()
205
- A_idx_value_var .name = "A_idx_value"
206
-
207
- I_value_var = I_rv .type ()
208
- I_value_var .name = "I_value"
209
-
210
- A_idx_logp = joint_logprob ({A_idx : A_idx_value_var , I_rv : I_value_var }, sum = False )
211
-
212
- logp_vals_fn = pytensor .function ([A_idx_value_var , I_value_var ], A_idx_logp )
213
-
214
- # The compiled graph should not contain any `RandomVariables`
215
- assert_no_rvs (logp_vals_fn .maker .fgraph .outputs [0 ])
216
-
217
- decimals = 6 if pytensor .config .floatX == "float64" else 4
218
-
219
- test_val_rng = np .random .RandomState (3238 )
220
-
221
- for i in range (10 ):
222
- bern_sp = sp .bernoulli (p )
223
- I_value = bern_sp .rvs (size = size , random_state = test_val_rng ).astype (I_rv .dtype )
224
-
225
- norm_sp = sp .norm (mu [I_value , np .ogrid [mu .shape [1 ] :]], sigma )
226
- A_idx_value = norm_sp .rvs (random_state = test_val_rng ).astype (A_idx .dtype )
227
-
228
- exp_obs_logps = norm_sp .logpdf (A_idx_value )
229
- exp_obs_logps += bern_sp .logpmf (I_value )
230
-
231
- logp_vals = logp_vals_fn (A_idx_value , I_value )
232
-
233
- np .testing .assert_almost_equal (logp_vals , exp_obs_logps , decimal = decimals )
234
-
235
-
236
188
def test_persist_inputs ():
237
189
"""Make sure we don't unnecessarily clone variables."""
238
190
x = pt .scalar ("x" )
@@ -242,24 +194,27 @@ def test_persist_inputs():
242
194
beta_vv = beta_rv .type ()
243
195
y_vv = Y_rv .clone ()
244
196
245
- logp = joint_logprob ({beta_rv : beta_vv , Y_rv : y_vv })
197
+ logp = factorized_joint_logprob ({beta_rv : beta_vv , Y_rv : y_vv })
198
+ logp_combined = pt .sum ([pt .sum (factor ) for factor in logp .values ()])
246
199
247
- assert x in ancestors ([logp ])
200
+ assert x in ancestors ([logp_combined ])
248
201
249
202
# Make sure we don't clone value variables when they're graphs.
250
203
y_vv_2 = y_vv * 2
251
- logp_2 = joint_logprob ({beta_rv : beta_vv , Y_rv : y_vv_2 })
204
+ logp_2 = factorized_joint_logprob ({beta_rv : beta_vv , Y_rv : y_vv_2 })
205
+ logp_2_combined = pt .sum ([pt .sum (factor ) for factor in logp_2 .values ()])
252
206
253
- assert y_vv in ancestors ([logp_2 ])
254
- assert y_vv_2 in ancestors ([logp_2 ])
207
+ assert y_vv in ancestors ([logp_2_combined ])
208
+ assert y_vv_2 in ancestors ([logp_2_combined ])
255
209
256
210
# Even when they are random
257
211
y_vv = pt .random .normal (name = "y_vv2" )
258
212
y_vv_2 = y_vv * 2
259
- logp_2 = joint_logprob ({beta_rv : beta_vv , Y_rv : y_vv_2 })
213
+ logp_2 = factorized_joint_logprob ({beta_rv : beta_vv , Y_rv : y_vv_2 })
214
+ logp_2_combined = pt .sum ([pt .sum (factor ) for factor in logp_2 .values ()])
260
215
261
- assert y_vv in ancestors ([logp_2 ])
262
- assert y_vv_2 in ancestors ([logp_2 ])
216
+ assert y_vv in ancestors ([logp_2_combined ])
217
+ assert y_vv_2 in ancestors ([logp_2_combined ])
263
218
264
219
265
220
def test_warn_random_found_factorized_joint_logprob ():
@@ -284,7 +239,7 @@ def test_multiple_rvs_to_same_value_raises():
284
239
285
240
msg = "More than one logprob factor was assigned to the value var x"
286
241
with pytest .raises (ValueError , match = msg ):
287
- joint_logprob ({x_rv1 : x , x_rv2 : x })
242
+ factorized_joint_logprob ({x_rv1 : x , x_rv2 : x })
288
243
289
244
290
245
def test_joint_logp_basic ():
0 commit comments