Skip to content

Commit d650afd

Browse files
Fix RaveledVars usage in BinaryGibbsMetropolis.astep
1 parent f3fe8ba commit d650afd

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

pymc3/step_methods/metropolis.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -427,28 +427,24 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
427427

428428
def astep(self, q0: RaveledVars, logp) -> RaveledVars:
429429

430-
point_map_info = q0.point_map_info
431-
q0 = q0.data
432-
433430
order = self.order
434431
if self.shuffle_dims:
435432
nr.shuffle(order)
436433

437-
q = np.copy(q0)
434+
q = RaveledVars(np.copy(q0.data), q0.point_map_info)
435+
438436
logp_curr = logp(q)
439437

440438
for idx in order:
441439
# No need to do metropolis update if the same value is proposed,
442440
# as you will get the same value regardless of accepted or reject
443441
if nr.rand() < self.transit_p:
444-
curr_val, q[idx] = q[idx], True - q[idx]
442+
curr_val, q.data[idx] = q.data[idx], True - q.data[idx]
445443
logp_prop = logp(q)
446-
q[idx], accepted = metrop_select(logp_prop - logp_curr, q[idx], curr_val)
444+
q.data[idx], accepted = metrop_select(logp_prop - logp_curr, q.data[idx], curr_val)
447445
if accepted:
448446
logp_curr = logp_prop
449447

450-
q = RaveledVars(q, point_map_info)
451-
452448
return q
453449

454450
@staticmethod

0 commit comments

Comments
 (0)