
Commit 313e007

Introduce Model.initial_values and deprecate testval in favor of initval
Parent: 26e5235

8 files changed: +135 −56 lines

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 - ⚠ Theano-PyMC has been replaced with Aesara, so all external references to `theano`, `tt`, and `pymc3.theanof` need to be replaced with `aesara`, `at`, and `pymc3.aesaraf` (see [4471](https://github.com/pymc-devs/pymc3/pull/4471)).
 - ArviZ `plots` and `stats` *wrappers* were removed. The functions are now just available by their original names (see [#4549](https://github.com/pymc-devs/pymc3/pull/4471) and `3.11.2` release notes).
 - The GLM submodule has been removed, please use [Bambi](https://bambinos.github.io/bambi/) instead.
+- The `Distribution` keyword argument `testval` has been deprecated in favor of `initval`.
 - ...

 ### New Features

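As a hedged usage sketch (not part of the diff), the user-facing effect of this entry is that `initval` replaces `testval` when creating model variables, while `testval` keeps working for now but emits a `DeprecationWarning`:

import warnings

import pymc3 as pm

with pm.Model() as model:
    # New keyword: sets the value used by model.initial_point
    x = pm.Normal("x", mu=0.0, sigma=1.0, initval=0.5)

    # Old keyword still works for now, but warns
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        y = pm.Normal("y", mu=0.0, sigma=1.0, testval=-1.0)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

print(model.initial_point)  # starting values keyed by value-variable name
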
pymc3/distributions/continuous.py

Lines changed: 12 additions & 8 deletions
@@ -332,10 +332,12 @@ class Flat(Continuous):
     rv_op = flat

     @classmethod
-    def dist(cls, *, size=None, testval=None, **kwargs):
-        if testval is None:
-            testval = np.full(size, floatX(0.0))
-        return super().dist([], size=size, testval=testval, **kwargs)
+    def dist(cls, *, size=None, initval=None, **kwargs):
+        if initval is None:
+            initval = np.full(size, floatX(0.0))
+        res = super().dist([], size=size, **kwargs)
+        res.tag.test_value = initval
+        return res

     def logp(value):
         """
@@ -394,10 +396,12 @@ class HalfFlat(PositiveContinuous):
     rv_op = halfflat

     @classmethod
-    def dist(cls, *, size=None, testval=None, **kwargs):
-        if testval is None:
-            testval = np.full(size, floatX(1.0))
-        return super().dist([], size=size, testval=testval, **kwargs)
+    def dist(cls, *, size=None, initval=None, **kwargs):
+        if initval is None:
+            initval = np.full(size, floatX(1.0))
+        res = super().dist([], size=size, **kwargs)
+        res.tag.test_value = initval
+        return res

     def logp(value):
         """

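Hedged sketch of the practical effect (not in the diff): `Flat` and `HalfFlat` have no forward sampler, so they attach default start values (zeros and ones, respectively) to the variable's `tag.test_value`, which `Model.set_initval` later copies into `Model.initial_values`:

import numpy as np

import pymc3 as pm

with pm.Model() as m:
    # Flat cannot be drawn from, so it falls back to the default initval of zeros
    a = pm.Flat("a", size=3)

start = m.initial_point
assert np.allclose(start["a"], 0.0)
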
pymc3/distributions/distribution.py

Lines changed: 27 additions & 4 deletions
@@ -143,21 +143,44 @@ def __new__(cls, name, *args, **kwargs):
         if "shape" in kwargs:
             raise DeprecationWarning("The `shape` keyword is deprecated; use `size`.")

+        testval = kwargs.pop("testval", None)
+
+        if testval is not None:
+            warnings.warn(
+                "The `testval` argument is deprecated; use `initval`.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        initval = kwargs.pop("initval", testval)
+
         transform = kwargs.pop("transform", UNSET)

         rv_out = cls.dist(*args, rng=rng, **kwargs)

-        return model.register_rv(rv_out, name, data, total_size, dims=dims, transform=transform)
+        if testval is not None:
+            rv_out.tag.test_value = testval
+
+        return model.register_rv(
+            rv_out, name, data, total_size, dims=dims, transform=transform, initval=initval
+        )

     @classmethod
     def dist(cls, dist_params, rng=None, **kwargs):

         testval = kwargs.pop("testval", None)

-        rv_var = cls.rv_op(*dist_params, rng=rng, **kwargs)
-
         if testval is not None:
-            rv_var.tag.test_value = testval
+            warnings.warn(
+                "The `testval` argument is deprecated. "
+                "Use `initval` to set initial values for a `Model`; "
+                "otherwise, set test values on Aesara parameters explicitly "
+                "when attempting to use Aesara's test value debugging features.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        rv_var = cls.rv_op(*dist_params, rng=rng, **kwargs)

         if (
             rv_var.owner
pymc3/model.py

Lines changed: 41 additions & 32 deletions
@@ -40,7 +40,7 @@
 from aesara.compile.sharedvalue import SharedVariable
 from aesara.gradient import grad
 from aesara.graph.basic import Constant, Variable, graph_inputs
-from aesara.graph.fg import FunctionGraph, MissingInputError
+from aesara.graph.fg import FunctionGraph
 from aesara.tensor.random.opt import local_subtensor_rv_lift
 from aesara.tensor.random.var import RandomStateSharedVariable
 from aesara.tensor.sharedvar import ScalarSharedVariable
@@ -572,7 +572,7 @@ def __init__(self, mean=0, sigma=1, name='', model=None):
             Normal('v2', mu=mean, sigma=sd)

             # something more complex is allowed, too
-            half_cauchy = HalfCauchy('sd', beta=10, testval=1.)
+            half_cauchy = HalfCauchy('sd', beta=10, initval=1.)
             Normal('v3', mu=mean, sigma=half_cauchy)

             # Deterministic variables can be used in usual way
@@ -649,6 +649,7 @@ def __init__(

         # The sequence of model-generated RNGs
         self.rng_seq = []
+        self.initial_values = {}

         if self.parent is not None:
             self.named_vars = treedict(parent=self.parent.named_vars)
@@ -914,35 +915,7 @@ def test_point(self):

     @property
     def initial_point(self):
-        points = []
-        for rv_var in self.free_RVs:
-            value_var = rv_var.tag.value_var
-            var_value = getattr(value_var.tag, "test_value", None)
-
-            if var_value is None:
-
-                rv_var_value = getattr(rv_var.tag, "test_value", None)
-
-                if rv_var_value is None:
-                    try:
-                        rv_var_value = rv_var.eval()
-                    except MissingInputError:
-                        raise MissingInputError(f"Couldn't generate an initial value for {rv_var}")
-
-                transform = getattr(value_var.tag, "transform", None)
-
-                if transform:
-                    try:
-                        rv_var_value = transform.forward(rv_var, rv_var_value).eval()
-                    except MissingInputError:
-                        raise MissingInputError(f"Couldn't generate an initial value for {rv_var}")
-
-                var_value = rv_var_value
-                value_var.tag.test_value = var_value
-
-            points.append((value_var, var_value))
-
-        return Point(points, model=self)
+        return Point(list(self.initial_values.items()), model=self)

     @property
     def disc_vars(self):
@@ -954,6 +927,37 @@ def cont_vars(self):
         """All the continuous variables in the model"""
         return list(typefilter(self.value_vars, continuous_types))

+    def set_initval(self, rv_var, initval):
+        initval = (
+            rv_var.type.filter(initval)
+            if initval is not None
+            else getattr(rv_var.tag, "test_value", None)
+        )
+
+        rv_value_var = self.rvs_to_values[rv_var]
+        transform = getattr(rv_value_var.tag, "transform", None)
+
+        if initval is None or transform:
+            # Sample/evaluate this using the existing initial values, and
+            # with the least amount of affect on the RNGs involved (i.e. no
+            # in-placing)
+            from aesara.compile.mode import Mode, get_mode
+
+            mode = get_mode(None)
+            opt_qry = mode.provided_optimizer.excluding("random_make_inplace")
+            mode = Mode(linker=mode.linker, optimizer=opt_qry)
+
+            if transform:
+                value = initval if initval is not None else rv_var
+                rv_var = transform.forward(rv_var, value)
+
+            initval_fn = aesara.function(
+                [], rv_var, mode=mode, givens=self.initial_values, on_unused_input="ignore"
+            )
+            initval = initval_fn()
+
+        self.initial_values[rv_value_var] = initval
+
     def next_rng(self) -> RandomStateSharedVariable:
         """Generate a new ``RandomStateSharedVariable``.

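
A hedged standalone sketch (not from the diff) of the compile-mode detail in `set_initval` above: the optimizer query excludes the `random_make_inplace` rewrite so that evaluating the variable for a fallback initial value has the least possible effect on the RNGs involved, i.e. no in-place updates:

import aesara
import aesara.tensor.random.basic as atr
import numpy as np

from aesara.compile.mode import Mode, get_mode

rng = aesara.shared(np.random.RandomState(123))
x = atr.normal(0, 1, rng=rng)  # a raw Aesara RandomVariable

mode = get_mode(None)
opt_qry = mode.provided_optimizer.excluding("random_make_inplace")
mode = Mode(linker=mode.linker, optimizer=opt_qry)

# Evaluate the graph once without letting the RNG be rewritten in place
draw_fn = aesara.function([], x, mode=mode, on_unused_input="ignore")
initial_value = draw_fn()
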
@@ -1116,7 +1120,9 @@ def set_data(

         shared_object.set_value(values)

-    def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, transform=UNSET):
+    def register_rv(
+        self, rv_var, name, data=None, total_size=None, dims=None, transform=UNSET, initval=None
+    ):
         """Register an (un)observed random variable with the model.

         Parameters
@@ -1132,6 +1138,8 @@ def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, trans
             Dimension names for the variable.
         transform
             A transform for the random variable in log-likelihood space.
+        initval
+            The initial value of the random variable.

         Returns
         -------
@@ -1145,6 +1153,7 @@ def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, trans
             self.free_RVs.append(rv_var)
             self.create_value_var(rv_var, transform)
             self.add_random_variable(rv_var, dims)
+            self.set_initval(rv_var, initval)
         else:
             if (
                 isinstance(data, Variable)

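For context, a hedged usage sketch (not part of the commit) of the new `Model.initial_values` mapping and the simplified `initial_point` property:

import pymc3 as pm

with pm.Model() as m:
    # An explicit initval is filtered through the variable's type; because of the
    # default log transform it ends up stored on the transformed scale
    sd = pm.HalfCauchy("sd", beta=10, initval=1.0)
    # Without an initval, set_initval evaluates the variable (given the initial
    # values registered so far) to produce a starting value
    x = pm.Normal("x", mu=0.0, sigma=sd)

# initial_values maps value variables to numeric starting values;
# initial_point is now just a Point built from that mapping
print(m.initial_values)
print(m.initial_point)
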
pymc3/tests/test_distributions.py

Lines changed: 24 additions & 1 deletion
@@ -594,6 +594,7 @@ def check_logp(
         n_samples=100,
         extra_args=None,
         scipy_args=None,
+        skip_params_fn=lambda x: False,
     ):
         """
         Generic test for PyMC3 logp methods
@@ -625,6 +626,9 @@ def check_logp(
             the pymc3 distribution logp is calculated
         scipy_args : Dictionary with extra arguments needed to call scipy logp method
             Usually the same as extra_args
+        skip_params_fn: Callable
+            A function that takes a ``dict`` of the test points and returns a
+            boolean indicating whether or not to perform the test.
         """
         if decimal is None:
             decimal = select_by_precision(float64=6, float32=3)
@@ -646,6 +650,8 @@ def logp_reference(args):
         domains["value"] = domain
         for pt in product(domains, n_samples=n_samples):
             pt = dict(pt)
+            if skip_params_fn(pt):
+                continue
             pt_d = self._model_input_dict(model, param_vars, pt)
             pt_logp = Point(pt_d, model=model)
             pt_ref = Point(pt, filter_model_vars=False, model=model)
@@ -690,6 +696,7 @@ def check_logcdf(
         n_samples=100,
         skip_paramdomain_inside_edge_test=False,
         skip_paramdomain_outside_edge_test=False,
+        skip_params_fn=lambda x: False,
     ):
         """
         Generic test for PyMC3 logcdf methods
@@ -730,6 +737,9 @@ def check_logcdf(
         skip_paramdomain_outside_edge_test : Bool
             Whether to run test 2., which checks that pymc3 distribution logcdf
             returns -inf for invalid parameter values outside the supported domain edge
+        skip_params_fn: Callable
+            A function that takes a ``dict`` of the test points and returns a
+            boolean indicating whether or not to perform the test.

         Returns
         -------
@@ -745,6 +755,8 @@ def check_logcdf(

         for pt in product(domains, n_samples=n_samples):
             params = dict(pt)
+            if skip_params_fn(params):
+                continue
             scipy_cdf = scipy_logcdf(**params)
             value = params.pop("value")
             with Model() as m:
@@ -825,7 +837,13 @@ def check_logcdf(
         )

     def check_selfconsistency_discrete_logcdf(
-        self, distribution, domain, paramdomains, decimal=None, n_samples=100
+        self,
+        distribution,
+        domain,
+        paramdomains,
+        decimal=None,
+        n_samples=100,
+        skip_params_fn=lambda x: False,
     ):
         """
         Check that logcdf of discrete distributions matches sum of logps up to value
@@ -836,6 +854,8 @@ def check_selfconsistency_discrete_logcdf(
             decimal = select_by_precision(float64=6, float32=3)
         for pt in product(domains, n_samples=n_samples):
             params = dict(pt)
+            if skip_params_fn(params):
+                continue
             value = params.pop("value")
             values = np.arange(domain.lower, value + 1)
             dist = distribution.dist(**params)
@@ -1187,17 +1207,20 @@ def modified_scipy_hypergeom_logcdf(value, N, k, n):
             Nat,
             {"N": NatSmall, "k": NatSmall, "n": NatSmall},
             modified_scipy_hypergeom_logpmf,
+            skip_params_fn=lambda x: x["N"] < x["n"] or x["N"] < x["k"],
         )
         self.check_logcdf(
             HyperGeometric,
             Nat,
             {"N": NatSmall, "k": NatSmall, "n": NatSmall},
            modified_scipy_hypergeom_logcdf,
+            skip_params_fn=lambda x: x["N"] < x["n"] or x["N"] < x["k"],
         )
         self.check_selfconsistency_discrete_logcdf(
             HyperGeometric,
             Nat,
             {"N": NatSmall, "k": NatSmall, "n": NatSmall},
+            skip_params_fn=lambda x: x["N"] < x["n"] or x["N"] < x["k"],
         )

     def test_negative_binomial(self):

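For context (hedged, not in the diff): the new `skip_params_fn` hook lets these generic checks skip randomly generated parameter points that are invalid for the distribution, such as hypergeometric cases where the sample size `n` or the number of successes `k` exceeds the population size `N`:

# Hypothetical standalone predicate mirroring the lambdas passed above
def invalid_hypergeom_params(pt):
    return pt["N"] < pt["n"] or pt["N"] < pt["k"]

assert invalid_hypergeom_params({"N": 5, "k": 2, "n": 7})       # n > N: point is skipped
assert not invalid_hypergeom_params({"N": 10, "k": 3, "n": 4})  # valid point: test runs
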
pymc3/tests/test_distributions_random.py

Lines changed: 4 additions & 8 deletions
@@ -327,17 +327,13 @@ def test_distribution(self):

     def _instantiate_pymc_rv(self, dist_params=None):
         params = dist_params if dist_params else self.pymc_dist_params
-        with pm.Model():
-            self.pymc_rv = self.pymc_dist(
-                **params,
-                size=self.size,
-                rng=aesara.shared(self.get_random_state(reset=True)),
-                name=f"{self.pymc_dist.rv_op.name}_test",
-            )
+        self.pymc_rv = self.pymc_dist.dist(
+            **params, size=self.size, rng=aesara.shared(self.get_random_state(reset=True))
+        )

     def check_pymc_draws_match_reference(self):
         # need to re-instantiate it to make sure that the order of drawings match the reference distribution one
-        self._instantiate_pymc_rv()
+        # self._instantiate_pymc_rv()
         assert_array_almost_equal(
             self.pymc_rv.eval(), self.reference_dist_draws, decimal=self.decimal
         )

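Hedged sketch of the pattern the test helper now relies on (not in the diff): `.dist()` builds a raw random variable without registering it in a `Model`, so reference draws only need an explicit RNG shared variable:

import aesara
import numpy as np

import pymc3 as pm

rng = aesara.shared(np.random.RandomState(1234))
rv = pm.Normal.dist(mu=0.0, sigma=1.0, size=5, rng=rng)
print(rv.eval())  # draws produced with the given RNG; reproducible by recreating it
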
pymc3/tests/test_logp.py

Lines changed: 7 additions & 2 deletions
@@ -18,7 +18,7 @@
 import scipy.stats.distributions as sp

 from aesara.gradient import DisconnectedGrad
-from aesara.graph.basic import Constant, graph_inputs
+from aesara.graph.basic import Constant, ancestors, graph_inputs
 from aesara.graph.fg import FunctionGraph
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.subtensor import (
@@ -38,6 +38,11 @@
 from pymc3.tests.helpers import select_by_precision


+def assert_no_rvs(var):
+    assert not any(isinstance(v.owner.op, RandomVariable) for v in ancestors([var]) if v.owner)
+    return var
+
+
 def test_logpt_basic():
     """Make sure we can compute a log-likelihood for a hierarchical model with transforms."""

@@ -171,7 +176,7 @@ def test_logpt_subtensor():
     logp_vals_fn = aesara.function([A_idx_value_var, I_value_var], A_idx_logp)

     # The compiled graph should not contain any `RandomVariables`
-    assert not any(isinstance(n.op, RandomVariable) for n in logp_vals_fn.maker.fgraph.apply_nodes)
+    assert_no_rvs(logp_vals_fn.maker.fgraph.outputs[0])

     decimals = select_by_precision(float64=6, float32=4)

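A hedged usage sketch of the `assert_no_rvs` helper added above: it walks the symbolic ancestors of a single output, so it can be applied to any tensor, not only a compiled `FunctionGraph`:

import aesara.tensor as at

from aesara.graph.basic import ancestors
from aesara.tensor.random.op import RandomVariable

def assert_no_rvs(var):
    # Same check as in the diff: fail if any ancestor node is a RandomVariable
    assert not any(isinstance(v.owner.op, RandomVariable) for v in ancestors([var]) if v.owner)
    return var

x = at.vector("x")
assert_no_rvs((x ** 2).sum())  # passes: a purely deterministic graph
# Passing a graph that still contains a RandomVariable would raise an AssertionError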