diff --git a/pymc3/distributions/continuous.py b/pymc3/distributions/continuous.py index 7e33ed28f3..b6f15345ba 100644 --- a/pymc3/distributions/continuous.py +++ b/pymc3/distributions/continuous.py @@ -87,7 +87,7 @@ def polyagamma_cdf(*args, **kwargs): ) from pymc3.distributions.distribution import Continuous from pymc3.math import logdiffexp, logit -from pymc3.util import UNSET +from pymc3.util import UNSET, select_initval __all__ = [ "Uniform", @@ -366,6 +366,18 @@ def dist(cls, *, size=None, **kwargs): res.tag.test_value = np.full(size, floatX(0.0)) return res + @classmethod + def pick_initval(cls, *, size=None, initval=UNSET, **kwargs) -> np.ndarray: + initval = select_initval( + candidate=initval, + default=np.full(size, floatX(0.0)), + ) + if initval is None: + raise NotImplementedError( + "The `Flat` distribution does not support random initval sampling (`initval=None`)." + ) + return initval + def logp(value): """ Calculate log-probability of Flat distribution at specified value. @@ -428,6 +440,18 @@ def dist(cls, *, size=None, **kwargs): res.tag.test_value = np.full(size, floatX(1.0)) return res + @classmethod + def pick_initval(cls, *, size=None, initval=UNSET, **kwargs) -> np.ndarray: + initval = select_initval( + candidate=initval, + default=np.full(size, floatX(1.0)), + ) + if initval is None: + raise NotImplementedError( + "The `HalfFlat` distribution does not support random initval sampling (`initval=None`)." + ) + return initval + def logp(value): """ Calculate log-probability of HalfFlat distribution at specified value. diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 1d15c46deb..fd5c5df00a 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -19,11 +19,13 @@ import warnings from abc import ABCMeta -from typing import Optional +from typing import Optional, Union import aesara import aesara.tensor as at +import numpy as np +from aesara.graph.basic import Constant, Variable from aesara.tensor.random.op import RandomVariable from aesara.tensor.random.var import RandomStateSharedVariable @@ -42,7 +44,7 @@ resize_from_observed, ) from pymc3.printing import str_for_dist -from pymc3.util import UNSET +from pymc3.util import UNSET, select_initval from pymc3.vartypes import string_types __all__ = [ @@ -126,7 +128,7 @@ def __new__( *args, rng=None, dims: Optional[Dims] = None, - initval=None, + initval=UNSET, observed=None, total_size=None, transform=UNSET, @@ -149,6 +151,9 @@ def __new__( initval : optional Test value to be attached to the output RV. Must match its shape exactly. + If set to `None`, an initial value will be drawn randomly. + With a value of `UNSET`, or not passing `initval` the `cls` distribution + may provide a default initial value at its own discretion. observed : optional Observed data to be passed when registering the random variable in the model. See ``Model.register_rv``. @@ -219,6 +224,10 @@ def __new__( # A batch size was specified through `dims`, or implied by `observed`. rv_out = change_rv_size(rv_var=rv_out, new_size=resize_shape, expand=True) + initval = select_initval( + candidate=initval, + default=cls.pick_initval(*args, rng=rng, **kwargs), + ) rv_out = model.register_rv( rv_out, name, @@ -327,6 +336,14 @@ def dist( return rv_out + @classmethod + def pick_initval(cls, *args, **kwargs) -> Union[int, float, np.ndarray, Constant, Variable]: + """Fallback method for creating an initial value for a random variable. + + Parameters are the same as for the `.dist()` method. + """ + return None + class NoDistribution(Distribution): def __init__( diff --git a/pymc3/model.py b/pymc3/model.py index df3c486338..23a3dd28e9 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -944,15 +944,13 @@ def cont_vars(self): return list(typefilter(self.value_vars, continuous_types)) def set_initval(self, rv_var, initval): - if initval is not None: + if not isinstance(initval, (type(None), TensorVariable)): initval = rv_var.type.filter(initval) - test_value = getattr(rv_var.tag, "test_value", None) - rv_value_var = self.rvs_to_values[rv_var] transform = getattr(rv_value_var.tag, "transform", None) - if initval is None or transform: + if transform or isinstance(initval, (type(None), TensorVariable)): # Sample/evaluate this using the existing initial values, and # with the least effect on the RNGs involved (i.e. no in-placing) from aesara.compile.mode import Mode, get_mode @@ -981,17 +979,7 @@ def initval_to_rvval(value_var, value): initval_fn = aesara.function( [], rv_var, mode=mode, givens=givens, on_unused_input="ignore" ) - try: - initval = initval_fn() - except NotImplementedError as ex: - if "Cannot sample from" in ex.args[0]: - # The RV does not have a random number generator. - # Our last chance is to take the test_value. - # Note that this is a workaround for Flat and HalfFlat - # until an initval default mechanism is implemented (#4752). - initval = test_value - else: - raise + initval = initval_fn() self.initial_values[rv_value_var] = initval diff --git a/pymc3/tests/test_initvals.py b/pymc3/tests/test_initvals.py index f445c88db0..b9745e1def 100644 --- a/pymc3/tests/test_initvals.py +++ b/pymc3/tests/test_initvals.py @@ -11,10 +11,69 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import aesara +import aesara.tensor as at +import aesara.tensor.random.basic as atr +import numpy as np +import pandas as pd import pytest +from aesara.graph.basic import Variable + import pymc3 as pm +from pymc3.util import UNSET, select_initval + + +def test_util_select_initval(): + # The candidate is preferred, particularly if the default is invalid + assert select_initval(4, default=None) == 4 + assert select_initval([1, 2], default=None) == [1, 2] + assert select_initval(None, default=3) is None + assert select_initval(UNSET, default=pd.Series([1, 2, 3])) is UNSET + + # The default is preferred if the candidate is UNSET or invalid + assert select_initval(UNSET, default=3) == 3 + assert select_initval(UNSET, default=None) is None + assert select_initval(pd.Series([1, 2, 3]), default=None) is None + assert isinstance(select_initval(UNSET, default=at.scalar()), Variable) + assert isinstance(select_initval(at.scalar(), default=3), Variable) + + # None is the fallback if both are invalid + assert select_initval(pd.Series([1, 2, 3]), default="not good") is None + pass + + +class NormalWithoutInitval(pm.Distribution): + """A distribution that does not specify a default initial value.""" + + rv_op = atr.normal + + @classmethod + def dist(cls, mu=0, sigma=None, **kwargs): + mu = at.as_tensor(pm.floatX(mu)) + sigma = at.as_tensor(pm.floatX(sigma)) + return super().dist([mu, sigma], **kwargs) + + +class UniformWithInitval(pm.distributions.continuous.BoundedContinuous): + """ + A distribution that defaults the initial value. + """ + + rv_op = atr.uniform + bound_args_indices = (0, 1) # Lower, Upper + + @classmethod + def dist(cls, lower=0, upper=1, initval=UNSET, **kwargs): + lower = at.as_tensor_variable(pm.floatX(lower)) + upper = at.as_tensor_variable(pm.floatX(upper)) + return super().dist([lower, upper], **kwargs) + + @classmethod + def pick_initval(cls, lower, upper, **kwargs): + return (lower + upper) / 2 + def transform_fwd(rv, expected_untransformed): return rv.tag.value_var.tag.transform.forward(rv, expected_untransformed).eval() @@ -22,6 +81,9 @@ def transform_fwd(rv, expected_untransformed): class TestInitvalAssignment: def test_dist_warnings_and_errors(self): + rv = UniformWithInitval.dist(1, 2) + assert not hasattr(rv.tag, "test_value") + with pytest.warns(DeprecationWarning, match="argument is deprecated and has no effect"): rv = pm.Exponential.dist(lam=1, testval=0.5) assert not hasattr(rv.tag, "test_value") @@ -38,8 +100,66 @@ def test_new_warnings(self): assert not hasattr(rv.tag, "test_value") pass + def test_new_initval_behaviors(self): + """ + No test values are set on the RV unless specified by either the user or the RV Op. + But initial values are always determined and managed by the Model object. + """ + with pm.Model() as pmodel: + rv1 = NormalWithoutInitval("default to random draw", 1, 2) + rv2 = NormalWithoutInitval("default to random draw the second", 1, 2) + assert pmodel.initial_values[rv1.tag.value_var] != 1 + assert pmodel.initial_values[rv2.tag.value_var] != 1 + assert ( + pmodel.initial_values[rv1.tag.value_var] != pmodel.initial_values[rv2.tag.value_var] + ) + # Randomly drawn initvals are not attached to the rv: + assert not hasattr(rv1.tag, "test_value") + assert not hasattr(rv2.tag, "test_value") + + rv = NormalWithoutInitval("user provided", 1, 2, initval=-0.2) + assert pmodel.initial_values[rv.tag.value_var] == np.array( + -0.2, dtype=aesara.config.floatX + ) + assert not hasattr(rv.tag, "test_value") + + rv = UniformWithInitval("RVOp default", 1.5, 2) + assert pmodel.initial_values[rv.tag.value_var] == transform_fwd(rv, 1.75) + assert not hasattr(rv.tag, "test_value") + + rv = UniformWithInitval("user can override RVOp default", 1.5, 2, initval=1.8) + assert pmodel.initial_values[rv.tag.value_var] == transform_fwd(rv, 1.8) + assert not hasattr(rv.tag, "test_value") + + rv = UniformWithInitval("user can revert to random draw", 1.5, 2, initval=None) + assert pmodel.initial_values[rv.tag.value_var] != transform_fwd(rv, 1.75) + assert not hasattr(rv.tag, "test_value") + pass + + def test_symbolic_initval(self): + """A regression tests for https://github.com/pymc-devs/pymc3/issues/4911""" + with pm.Model() as pmodel: + a = pm.Normal("a") + b = pm.Normal("b", a, initval=a) + # Initval assignment should evaluate symbolics: + assert isinstance(pmodel.initial_point["b"], np.ndarray) + class TestSpecialDistributions: + def test_flat(self): + pm.Flat.pick_initval(initval=4) == 4 + pm.Flat.pick_initval(size=(2,), initval=UNSET) == np.array([0, 0]) + with pytest.raises(NotImplementedError, match="does not support random initval"): + pm.Flat.pick_initval(initval=None) + pass + + def test_halfflat(self): + pm.HalfFlat.pick_initval(initval=4) == 4 + pm.HalfFlat.pick_initval(size=(2,), initval=UNSET) == np.array([1, 1]) + with pytest.raises(NotImplementedError, match="does not support random initval"): + pm.HalfFlat.pick_initval(initval=None) + pass + def test_automatically_assigned_test_values(self): # ...because they don't have random number generators. rv = pm.Flat.dist() diff --git a/pymc3/util.py b/pymc3/util.py index 0d733d6718..abdd67e3c2 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -15,13 +15,14 @@ import functools import warnings -from typing import Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Union import arviz import cloudpickle import numpy as np import xarray +from aesara.graph.basic import Variable from cachetools import LRUCache, cachedmethod @@ -38,6 +39,36 @@ def __repr__(self): UNSET = _UnsetType() +def select_initval( + candidate: Any, + default: Any, +) -> Union[int, float, None, np.ndarray, _UnsetType, Variable]: + """Picks a compatible initial value such that it's either numeric, UNSET, a Variable, or None. + + Parameters + ---------- + candidate + A potential initval value. + If this incompatible or UNSET the default will be considered. + Typical values are UNSET, None or instances of ndarray, int, float or Variable. + default + A fallback initval value. + If this and the candidate incompatible `None` will be returned instead. + Typical values are UNSET, None or instances of ndarray, int, float or Variable. + """ + valid_types = (int, float, np.ndarray, list, tuple, type(None), Variable) + valid_candidate = isinstance(candidate, valid_types) or candidate is UNSET + valid_default = isinstance(default, valid_types) or default is UNSET + if isinstance(candidate, valid_types) or (valid_candidate and not valid_default): + return candidate + elif valid_default: + # The candidate is UNSET or incompatible, but a compatible default is available. + return default + # Neither candidate nor default can be used. + # With initval=None the Model with draw an initval randomly. + return None + + def withparent(meth): """Helper wrapper that passes calls to parent's instance"""