@@ -168,10 +168,12 @@ def __init__(
168
168
vars = model .value_vars
169
169
else :
170
170
vars = [model .rvs_to_values .get (var , var ) for var in vars ]
171
+
171
172
vars = pm .inputvars (vars )
172
173
174
+ initial_values_shape = [initial_values [v .name ].shape for v in vars ]
173
175
if S is None :
174
- S = np .ones (sum (initial_values [ v . name ]. size for v in vars ))
176
+ S = np .ones (int ( sum (np . prod ( ivs ) for ivs in initial_values_shape ) ))
175
177
176
178
if proposal_dist is not None :
177
179
self .proposal_dist = proposal_dist (S )
@@ -186,7 +188,6 @@ def __init__(
186
188
self .tune = tune
187
189
self .tune_interval = tune_interval
188
190
self .steps_until_tune = tune_interval
189
- self .accepted = 0
190
191
191
192
# Determine type of variables
192
193
self .discrete = np .concatenate (
@@ -195,11 +196,33 @@ def __init__(
195
196
self .any_discrete = self .discrete .any ()
196
197
self .all_discrete = self .discrete .all ()
197
198
198
- # remember initial settings before tuning so they can be reset
199
- self ._untuned_settings = dict (
200
- scaling = self .scaling , steps_until_tune = tune_interval , accepted = self .accepted
199
+ # Metropolis will try to handle one batched dimension at a time This, however,
200
+ # is not safe for discrete multivariate distributions (looking at you Multinomial),
201
+ # due to high dependency among the support dimensions. For continuous multivariate
202
+ # distributions we assume they are being transformed in a way that makes each
203
+ # dimension semi-independent.
204
+ is_scalar = len (initial_values_shape ) == 1 and initial_values_shape [0 ] == ()
205
+ self .elemwise_update = not (
206
+ is_scalar
207
+ or (
208
+ self .any_discrete
209
+ and max (getattr (model .values_to_rvs [var ].owner .op , "ndim_supp" , 1 ) for var in vars )
210
+ > 0
211
+ )
201
212
)
213
+ if self .elemwise_update :
214
+ dims = int (sum (np .prod (ivs ) for ivs in initial_values_shape ))
215
+ else :
216
+ dims = 1
217
+ self .enum_dims = np .arange (dims , dtype = int )
218
+ self .accept_rate_iter = np .zeros (dims , dtype = float )
219
+ self .accepted_iter = np .zeros (dims , dtype = bool )
220
+ self .accepted_sum = np .zeros (dims , dtype = int )
202
221
222
+ # remember initial settings before tuning so they can be reset
223
+ self ._untuned_settings = dict (scaling = self .scaling , steps_until_tune = tune_interval )
224
+
225
+ # TODO: This is not being used when compiling the logp function!
203
226
self .mode = mode
204
227
205
228
shared = pm .make_shared_replacements (initial_values , vars , model )
@@ -210,6 +233,7 @@ def reset_tuning(self):
210
233
"""Resets the tuned sampler parameters to their initial values."""
211
234
for attr , initial_value in self ._untuned_settings .items ():
212
235
setattr (self , attr , initial_value )
236
+ self .accepted_sum [:] = 0
213
237
return
214
238
215
239
def astep (self , q0 : RaveledVars ) -> Tuple [RaveledVars , List [Dict [str , Any ]]]:
@@ -219,10 +243,10 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
219
243
220
244
if not self .steps_until_tune and self .tune :
221
245
# Tune scaling parameter
222
- self .scaling = tune (self .scaling , self .accepted / float (self .tune_interval ))
246
+ self .scaling = tune (self .scaling , self .accepted_sum / float (self .tune_interval ))
223
247
# Reset counter
224
248
self .steps_until_tune = self .tune_interval
225
- self .accepted = 0
249
+ self .accepted_sum [:] = 0
226
250
227
251
delta = self .proposal_dist () * self .scaling
228
252
@@ -237,23 +261,36 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
237
261
else :
238
262
q = floatX (q0 + delta )
239
263
240
- accept = self .delta_logp (q , q0 )
241
- q_new , accepted = metrop_select (accept , q , q0 )
242
-
243
- self .accepted += accepted
264
+ if self .elemwise_update :
265
+ q_temp = q0 .copy ()
266
+ # Shuffle order of updates (probably we don't need to do this in every step)
267
+ np .random .shuffle (self .enum_dims )
268
+ for i in self .enum_dims :
269
+ q_temp [i ] = q [i ]
270
+ accept_rate_i = self .delta_logp (q_temp , q0 )
271
+ q_temp_ , accepted_i = metrop_select (accept_rate_i , q_temp , q0 )
272
+ q_temp [i ] = q_temp_ [i ]
273
+ self .accept_rate_iter [i ] = accept_rate_i
274
+ self .accepted_iter [i ] = accepted_i
275
+ self .accepted_sum [i ] += accepted_i
276
+ q = q_temp
277
+ else :
278
+ accept_rate = self .delta_logp (q , q0 )
279
+ q , accepted = metrop_select (accept_rate , q , q0 )
280
+ self .accept_rate_iter = accept_rate
281
+ self .accepted_iter = accepted
282
+ self .accepted_sum += accepted
244
283
245
284
self .steps_until_tune -= 1
246
285
247
286
stats = {
248
287
"tune" : self .tune ,
249
- "scaling" : self .scaling ,
250
- "accept" : np .exp (accept ),
251
- "accepted" : accepted ,
288
+ "scaling" : np . mean ( self .scaling ) ,
289
+ "accept" : np .mean ( np . exp (self . accept_rate_iter ) ),
290
+ "accepted" : np . mean ( self . accepted_iter ) ,
252
291
}
253
292
254
- q_new = RaveledVars (q_new , point_map_info )
255
-
256
- return q_new , [stats ]
293
+ return RaveledVars (q , point_map_info ), [stats ]
257
294
258
295
@staticmethod
259
296
def competence (var , has_grad ):
@@ -275,26 +312,38 @@ def tune(scale, acc_rate):
275
312
>0.95 x 10
276
313
277
314
"""
278
- if acc_rate < 0.001 :
315
+ return scale * np .where (
316
+ acc_rate < 0.001 ,
279
317
# reduce by 90 percent
280
- return scale * 0.1
281
- elif acc_rate < 0.05 :
282
- # reduce by 50 percent
283
- return scale * 0.5
284
- elif acc_rate < 0.2 :
285
- # reduce by ten percent
286
- return scale * 0.9
287
- elif acc_rate > 0.95 :
288
- # increase by factor of ten
289
- return scale * 10.0
290
- elif acc_rate > 0.75 :
291
- # increase by double
292
- return scale * 2.0
293
- elif acc_rate > 0.5 :
294
- # increase by ten percent
295
- return scale * 1.1
296
-
297
- return scale
318
+ 0.1 ,
319
+ np .where (
320
+ acc_rate < 0.05 ,
321
+ # reduce by 50 percent
322
+ 0.5 ,
323
+ np .where (
324
+ acc_rate < 0.2 ,
325
+ # reduce by ten percent
326
+ 0.9 ,
327
+ np .where (
328
+ acc_rate > 0.95 ,
329
+ # increase by factor of ten
330
+ 10.0 ,
331
+ np .where (
332
+ acc_rate > 0.75 ,
333
+ # increase by double
334
+ 2.0 ,
335
+ np .where (
336
+ acc_rate > 0.5 ,
337
+ # increase by ten percent
338
+ 1.1 ,
339
+ # Do not change
340
+ 1.0 ,
341
+ ),
342
+ ),
343
+ ),
344
+ ),
345
+ ),
346
+ )
298
347
299
348
300
349
class BinaryMetropolis (ArrayStep ):
0 commit comments