Skip to content

Commit ea074fe

Browse files
Fix RaveledVars and size-related issues in Metropolis and MLDA samplers
1 parent 7559c66 commit ea074fe

File tree

3 files changed

+36
-30
lines changed

3 files changed

+36
-30
lines changed

pymc3/sampling.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@
100100

101101

102102
def instantiate_steppers(
103-
_model, steps: List[Step], selected_steps, step_kwargs=None
103+
model, steps: List[Step], selected_steps, step_kwargs=None
104104
) -> Union[Step, List[Step]]:
105105
"""Instantiate steppers assigned to the model variables.
106106
@@ -110,7 +110,7 @@ def instantiate_steppers(
110110
Parameters
111111
----------
112112
model : Model object
113-
A fully-specified model object; legacy argument -- ignored
113+
A fully-specified model object
114114
steps : list
115115
A list of zero or more step function instances that have been assigned to some subset of
116116
the model's parameters.
@@ -134,7 +134,7 @@ def instantiate_steppers(
134134
if vars:
135135
args = step_kwargs.get(step_class.name, {})
136136
used_keys.add(step_class.name)
137-
step = step_class(vars=vars, **args)
137+
step = step_class(vars=vars, model=model, **args)
138138
steps.append(step)
139139

140140
unused_args = set(step_kwargs).difference(used_keys)
@@ -600,7 +600,7 @@ def sample(
600600
)
601601
_log.info(f"Population sampling ({chains} chains)")
602602

603-
initial_point_model_size = sum(start[n.name].size for n in model.value_vars)
603+
initial_point_model_size = sum(start[0][n.name].size for n in model.value_vars)
604604

605605
if has_demcmc and chains < 3:
606606
raise ValueError(
@@ -1014,7 +1014,7 @@ def _iter_sample(
10141014
except TypeError:
10151015
pass
10161016

1017-
point = Point(start, model=model)
1017+
point = Point(start, model=model, filter_model_vars=True)
10181018

10191019
if step.generates_stats and strace.supports_sampler_stats:
10201020
strace.setup(draws, chain, step.stats_dtypes)

pymc3/step_methods/metropolis.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import pymc3 as pm
2525

26-
from pymc3.aesaraf import floatX
26+
from pymc3.aesaraf import floatX, rvs_to_value_vars
2727
from pymc3.blocking import DictToArrayBijection, RaveledVars
2828
from pymc3.step_methods.arraystep import (
2929
ArrayStep,
@@ -408,8 +408,8 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
408408
# transition probabilities
409409
self.transit_p = transit_p
410410

411-
# XXX: This needs to be refactored
412-
self.dim = None # sum(v.dsize for v in vars)
411+
initial_point = model.initial_point
412+
self.dim = sum(initial_point[v.name].size for v in vars)
413413

414414
if order == "random":
415415
self.shuffle_dims = True
@@ -491,29 +491,35 @@ class CategoricalGibbsMetropolis(ArrayStep):
491491
def __init__(self, vars, proposal="uniform", order="random", model=None):
492492

493493
model = pm.modelcontext(model)
494+
494495
vars = pm.inputvars(vars)
495496

497+
initial_point = model.initial_point
498+
496499
dimcats = []
497500
# The above variable is a list of pairs (aggregate dimension, number
498501
# of categories). For example, if vars = [x, y] with x being a 2-D
499502
# variable with M categories and y being a 3-D variable with N
500503
# categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
501504
for v in vars:
502505

503-
distr = getattr(v.owner, "op", None)
506+
v_init_val = initial_point[v.name]
507+
508+
rv_var = model.values_to_rvs[v]
509+
distr = getattr(rv_var.owner, "op", None)
504510

505511
if isinstance(distr, CategoricalRV):
506-
# XXX: This needs to be refactored
507-
k = None # draw_values([distr.k])[0]
508-
elif isinstance(distr, pm.Bernoulli) or (v.dtype in pm.bool_types):
512+
k_graph = rv_var.owner.inputs[3].shape[-1]
513+
(k_graph,), _ = rvs_to_value_vars((k_graph,), apply_transforms=True)
514+
k = model.fn(k_graph)(initial_point)
515+
elif isinstance(distr, BernoulliRV):
509516
k = 2
510517
else:
511518
raise ValueError(
512519
"All variables must be categorical or binary" + "for CategoricalGibbsMetropolis"
513520
)
514521
start = len(dimcats)
515-
# XXX: This needs to be refactored
516-
dimcats += None # [(dim, k) for dim in range(start, start + v.dsize)]
522+
dimcats += [(dim, k) for dim in range(start, start + v_init_val.size)]
517523

518524
if order == "random":
519525
self.shuffle_dims = True
@@ -543,18 +549,16 @@ def astep_unif(self, q0: RaveledVars, logp) -> RaveledVars:
543549
if self.shuffle_dims:
544550
nr.shuffle(dimcats)
545551

546-
q = np.copy(q0)
552+
q = RaveledVars(np.copy(q0), point_map_info)
547553
logp_curr = logp(q)
548554

549555
for dim, k in dimcats:
550-
curr_val, q[dim] = q[dim], sample_except(k, q[dim])
556+
curr_val, q.data[dim] = q.data[dim], sample_except(k, q.data[dim])
551557
logp_prop = logp(q)
552-
q[dim], accepted = metrop_select(logp_prop - logp_curr, q[dim], curr_val)
558+
q.data[dim], accepted = metrop_select(logp_prop - logp_curr, q.data[dim], curr_val)
553559
if accepted:
554560
logp_curr = logp_prop
555561

556-
q = RaveledVars(q, point_map_info)
557-
558562
return q
559563

560564
def astep_prop(self, q0: RaveledVars, logp) -> RaveledVars:
@@ -566,34 +570,32 @@ def astep_prop(self, q0: RaveledVars, logp) -> RaveledVars:
566570
if self.shuffle_dims:
567571
nr.shuffle(dimcats)
568572

569-
q = np.copy(q0)
573+
q = RaveledVars(np.copy(q0), point_map_info)
570574
logp_curr = logp(q)
571575

572576
for dim, k in dimcats:
573577
logp_curr = self.metropolis_proportional(q, logp, logp_curr, dim, k)
574578

575-
q = RaveledVars(q, point_map_info)
576-
577579
return q
578580

579581
def metropolis_proportional(self, q, logp, logp_curr, dim, k):
580-
given_cat = int(q[dim])
582+
given_cat = int(q.data[dim])
581583
log_probs = np.zeros(k)
582584
log_probs[given_cat] = logp_curr
583585
candidates = list(range(k))
584586
for candidate_cat in candidates:
585587
if candidate_cat != given_cat:
586-
q[dim] = candidate_cat
588+
q.data[dim] = candidate_cat
587589
log_probs[candidate_cat] = logp(q)
588590
probs = softmax(log_probs)
589591
prob_curr, probs[given_cat] = probs[given_cat], 0.0
590592
probs /= 1.0 - prob_curr
591593
proposed_cat = nr.choice(candidates, p=probs)
592594
accept_ratio = (1.0 - prob_curr) / (1.0 - probs[proposed_cat])
593595
if not np.isfinite(accept_ratio) or nr.uniform() >= accept_ratio:
594-
q[dim] = given_cat
596+
q.data[dim] = given_cat
595597
return logp_curr
596-
q[dim] = proposed_cat
598+
q.data[dim] = proposed_cat
597599
return log_probs[proposed_cat]
598600

599601
@staticmethod
@@ -744,7 +746,7 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
744746
r1 = DictToArrayBijection.map(self.population[ir1])
745747
r2 = DictToArrayBijection.map(self.population[ir2])
746748
# propose a jump
747-
q = floatX(q0 + self.lamb * (r1 - r2) + epsilon)
749+
q = floatX(q0 + self.lamb * (r1.data - r2.data) + epsilon)
748750

749751
accept = self.delta_logp(q, q0)
750752
q_new, accepted = metrop_select(accept, q, q0)

pymc3/step_methods/mlda.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import pymc3 as pm
2727

2828
from pymc3.blocking import DictToArrayBijection
29-
from pymc3.model import Model
29+
from pymc3.model import Model, Point
3030
from pymc3.step_methods.arraystep import ArrayStepShared, Competence, metrop_select
3131
from pymc3.step_methods.compound import CompoundStep
3232
from pymc3.step_methods.metropolis import (
@@ -746,7 +746,8 @@ def astep(self, q0):
746746

747747
# Call the recursive DA proposal to get proposed sample
748748
# and convert dict -> numpy array
749-
q = DictToArrayBijection.map(self.proposal_dist(q0_dict))
749+
pre_q = self.proposal_dist(q0_dict)
750+
q = DictToArrayBijection.map(pre_q)
750751

751752
# Evaluate MLDA acceptance log-ratio
752753
# If proposed sample from lower levels is the same as current one,
@@ -1141,4 +1142,7 @@ def __call__(self, q0_dict: dict) -> dict:
11411142
# return sample with index self.subchain_selection from the generated
11421143
# sequence of length self.subsampling_rate. The index is set within
11431144
# MLDA's astep() function
1144-
return self.trace.point(-self.subsampling_rate + self.subchain_selection)
1145+
new_point = self.trace.point(-self.subsampling_rate + self.subchain_selection)
1146+
new_point = Point(new_point, model=self.model_below, filter_model_vars=True)
1147+
1148+
return new_point

0 commit comments

Comments
 (0)