Skip to content

Commit 48f4122

Browse files
brandonwillardtwiecki
authored andcommitted
Use the correct givens in Model.set_initval
1 parent 5c58640 commit 48f4122

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

pymc3/model.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -938,8 +938,7 @@ def set_initval(self, rv_var, initval):
938938

939939
if initval is None or transform:
940940
# Sample/evaluate this using the existing initial values, and
941-
# with the least amount of affect on the RNGs involved (i.e. no
942-
# in-placing)
941+
# with the least effect on the RNGs involved (i.e. no in-placing)
943942
from aesara.compile.mode import Mode, get_mode
944943

945944
mode = get_mode(None)
@@ -950,8 +949,21 @@ def set_initval(self, rv_var, initval):
950949
value = initval if initval is not None else rv_var
951950
rv_var = transform.forward(rv_var, value)
952951

952+
def initval_to_rvval(value_var, value):
953+
rv_var = self.values_to_rvs[value_var]
954+
initval = value_var.type.make_constant(value)
955+
transform = getattr(value_var.tag, "transform", None)
956+
if transform:
957+
return transform.backward(rv_var, initval)
958+
else:
959+
return initval
960+
961+
givens = {
962+
self.values_to_rvs[k]: initval_to_rvval(k, v)
963+
for k, v in self.initial_values.items()
964+
}
953965
initval_fn = aesara.function(
954-
[], rv_var, mode=mode, givens=self.initial_values, on_unused_input="ignore"
966+
[], rv_var, mode=mode, givens=givens, on_unused_input="ignore"
955967
)
956968
initval = initval_fn()
957969

pymc3/tests/test_model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,3 +632,27 @@ def test_invalid_variable_name(self):
632632
model.update_start_vals(start, model.initial_point)
633633
with pytest.raises(KeyError):
634634
model.check_start_vals(start)
635+
636+
637+
def test_set_initval():
638+
# Make sure the dependencies between variables are maintained when
639+
# generating initial values
640+
rng = np.random.RandomState(392)
641+
642+
with pm.Model(rng_seeder=rng) as model:
643+
eta = pm.Uniform("eta", 1.0, 2.0, size=(1, 1))
644+
mu = pm.Normal("mu", sd=eta, initval=[[100]])
645+
alpha = pm.HalfNormal("alpha", initval=100)
646+
value = pm.NegativeBinomial("value", mu=mu, alpha=alpha)
647+
648+
assert np.array_equal(model.initial_values[model.rvs_to_values[mu]], np.array([[100.0]]))
649+
np.testing.assert_almost_equal(model.initial_values[model.rvs_to_values[alpha]], np.log(100))
650+
assert 50 < model.initial_values[model.rvs_to_values[value]] < 150
651+
652+
# `Flat` cannot be sampled, so let's make sure that doesn't break initial
653+
# value computations
654+
with pm.Model() as model:
655+
x = pm.Flat("x")
656+
y = pm.Normal("y", x, 1)
657+
658+
assert model.rvs_to_values[y] in model.initial_values

0 commit comments

Comments
 (0)