Skip to content

Restore support for default initvals #4867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def polyagamma_cdf(*args, **kwargs):
)
from pymc3.distributions.distribution import Continuous
from pymc3.math import logdiffexp, logit
from pymc3.util import UNSET
from pymc3.util import UNSET, select_initval

__all__ = [
"Uniform",
Expand Down Expand Up @@ -366,6 +366,18 @@ def dist(cls, *, size=None, **kwargs):
res.tag.test_value = np.full(size, floatX(0.0))
return res

@classmethod
def pick_initval(cls, *, size=None, initval=UNSET, **kwargs) -> np.ndarray:
    """Choose the initial value for a ``Flat`` random variable.

    Parameters
    ----------
    size : optional
        Shape of the default initial value. ``None`` is treated as the
        scalar shape ``()``.
    initval : optional
        Initial value requested by the caller. It is passed to
        ``select_initval`` as the candidate, with an array of zeros of
        shape ``size`` as the default.
    **kwargs
        Accepted but not used by this method.

    Returns
    -------
    The value chosen by ``select_initval`` from ``initval`` and the
    zero-filled default.

    Raises
    ------
    NotImplementedError
        If the chosen initial value is ``None`` (random initval sampling).
    """
    # Passing ``None`` as a NumPy shape is deprecated (NumPy >= 1.20),
    # so spell out the scalar shape explicitly.
    shape = () if size is None else size
    initval = select_initval(
        candidate=initval,
        default=np.full(shape, floatX(0.0)),
    )
    if initval is None:
        raise NotImplementedError(
            "The `Flat` distribution does not support random initval sampling (`initval=None`)."
        )
    return initval

def logp(value):
"""
Calculate log-probability of Flat distribution at specified value.
Expand Down Expand Up @@ -428,6 +440,18 @@ def dist(cls, *, size=None, **kwargs):
res.tag.test_value = np.full(size, floatX(1.0))
return res

@classmethod
def pick_initval(cls, *, size=None, initval=UNSET, **kwargs) -> np.ndarray:
    """Choose the initial value for a ``HalfFlat`` random variable.

    Parameters
    ----------
    size : optional
        Shape of the default initial value. ``None`` is treated as the
        scalar shape ``()``.
    initval : optional
        Initial value requested by the caller. It is passed to
        ``select_initval`` as the candidate, with an array of ones of
        shape ``size`` as the default.
    **kwargs
        Accepted but not used by this method.

    Returns
    -------
    The value chosen by ``select_initval`` from ``initval`` and the
    one-filled default.

    Raises
    ------
    NotImplementedError
        If the chosen initial value is ``None`` (random initval sampling).
    """
    # Passing ``None`` as a NumPy shape is deprecated (NumPy >= 1.20),
    # so spell out the scalar shape explicitly.
    shape = () if size is None else size
    initval = select_initval(
        candidate=initval,
        default=np.full(shape, floatX(1.0)),
    )
    if initval is None:
        raise NotImplementedError(
            "The `HalfFlat` distribution does not support random initval sampling (`initval=None`)."
        )
    return initval

def logp(value):
"""
Calculate log-probability of HalfFlat distribution at specified value.
Expand Down
23 changes: 20 additions & 3 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
import warnings

from abc import ABCMeta
from typing import Optional
from typing import Optional, Union

import aesara
import aesara.tensor as at
import numpy as np

from aesara.graph.basic import Constant, Variable
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.var import RandomStateSharedVariable

Expand All @@ -42,7 +44,7 @@
resize_from_observed,
)
from pymc3.printing import str_for_dist
from pymc3.util import UNSET
from pymc3.util import UNSET, select_initval
from pymc3.vartypes import string_types

__all__ = [
Expand Down Expand Up @@ -126,7 +128,7 @@ def __new__(
*args,
rng=None,
dims: Optional[Dims] = None,
initval=None,
initval=UNSET,
observed=None,
total_size=None,
transform=UNSET,
Expand All @@ -149,6 +151,9 @@ def __new__(
initval : optional
Test value to be attached to the output RV.
Must match its shape exactly.
If set to `None`, an initial value will be drawn randomly.
With a value of `UNSET`, or when `initval` is not passed, the `cls` distribution
may provide a default initial value at its own discretion.
observed : optional
Observed data to be passed when registering the random variable in the model.
See ``Model.register_rv``.
Expand Down Expand Up @@ -219,6 +224,10 @@ def __new__(
# A batch size was specified through `dims`, or implied by `observed`.
rv_out = change_rv_size(rv_var=rv_out, new_size=resize_shape, expand=True)

initval = select_initval(
candidate=initval,
default=cls.pick_initval(*args, rng=rng, **kwargs),
)
rv_out = model.register_rv(
rv_out,
name,
Expand Down Expand Up @@ -327,6 +336,14 @@ def dist(

return rv_out

@classmethod
def pick_initval(
    cls, *args, **kwargs
) -> Optional[Union[int, float, np.ndarray, Constant, Variable]]:
    """Fallback method for creating an initial value for a random variable.

    Parameters are the same as for the `.dist()` method; this base
    implementation ignores all of them.

    Returns
    -------
    None
        The base implementation does not provide a default initial value.
        Subclasses can override this method to return one.
    """
    return None


class NoDistribution(Distribution):
def __init__(
Expand Down
18 changes: 3 additions & 15 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,15 +944,13 @@ def cont_vars(self):
return list(typefilter(self.value_vars, continuous_types))

def set_initval(self, rv_var, initval):
if initval is not None:
if not isinstance(initval, (type(None), TensorVariable)):
initval = rv_var.type.filter(initval)

test_value = getattr(rv_var.tag, "test_value", None)

rv_value_var = self.rvs_to_values[rv_var]
transform = getattr(rv_value_var.tag, "transform", None)

if initval is None or transform:
if transform or isinstance(initval, (type(None), TensorVariable)):
# Sample/evaluate this using the existing initial values, and
# with the least effect on the RNGs involved (i.e. no in-placing)
from aesara.compile.mode import Mode, get_mode
Expand Down Expand Up @@ -981,17 +979,7 @@ def initval_to_rvval(value_var, value):
initval_fn = aesara.function(
[], rv_var, mode=mode, givens=givens, on_unused_input="ignore"
)
try:
initval = initval_fn()
except NotImplementedError as ex:
if "Cannot sample from" in ex.args[0]:
# The RV does not have a random number generator.
# Our last chance is to take the test_value.
# Note that this is a workaround for Flat and HalfFlat
# until an initval default mechanism is implemented (#4752).
initval = test_value
else:
raise
initval = initval_fn()

self.initial_values[rv_value_var] = initval

Expand Down
120 changes: 120 additions & 0 deletions pymc3/tests/test_initvals.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,79 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import aesara
import aesara.tensor as at
import aesara.tensor.random.basic as atr
import numpy as np
import pandas as pd
import pytest

from aesara.graph.basic import Variable

import pymc3 as pm

from pymc3.util import UNSET, select_initval


def test_util_select_initval():
    """Exercise the precedence rules of ``select_initval``."""
    unsupported = pd.Series([1, 2, 3])

    # A usable candidate wins; UNSET is kept when the default is unusable.
    assert select_initval(candidate=4, default=None) == 4
    assert select_initval(candidate=[1, 2], default=None) == [1, 2]
    assert select_initval(candidate=None, default=3) is None
    assert select_initval(candidate=UNSET, default=unsupported) is UNSET

    # An UNSET or unusable candidate yields to a usable default.
    assert select_initval(candidate=UNSET, default=3) == 3
    assert select_initval(candidate=UNSET, default=None) is None
    assert select_initval(candidate=unsupported, default=None) is None
    symbolic_default = select_initval(candidate=UNSET, default=at.scalar())
    assert isinstance(symbolic_default, Variable)
    symbolic_candidate = select_initval(candidate=at.scalar(), default=3)
    assert isinstance(symbolic_candidate, Variable)

    # With neither value usable the result is None.
    assert select_initval(candidate=unsupported, default="not good") is None


class NormalWithoutInitval(pm.Distribution):
    """Normal test distribution that defines no default initial value of its own."""

    rv_op = atr.normal

    @classmethod
    def dist(cls, mu=0, sigma=None, **kwargs):
        """Convert ``mu`` and ``sigma`` to floatX tensors and delegate to the parent ``dist``."""
        dist_params = [at.as_tensor(pm.floatX(param)) for param in (mu, sigma)]
        return super().dist(dist_params, **kwargs)


class UniformWithInitval(pm.distributions.continuous.BoundedContinuous):
    """Uniform test distribution that supplies its own default initial value."""

    rv_op = atr.uniform
    bound_args_indices = (0, 1)  # Lower, Upper

    @classmethod
    def dist(cls, lower=0, upper=1, initval=UNSET, **kwargs):
        """Build the RV from floatX bounds; ``initval`` is accepted but not forwarded."""
        bounds = [at.as_tensor_variable(pm.floatX(bound)) for bound in (lower, upper)]
        return super().dist(bounds, **kwargs)

    @classmethod
    def pick_initval(cls, lower, upper, **kwargs):
        """Return the midpoint between ``lower`` and ``upper``."""
        midpoint = (lower + upper) / 2
        return midpoint


def transform_fwd(rv, expected_untransformed):
    """Apply the forward transform of ``rv``'s value variable and evaluate the result."""
    transform = rv.tag.value_var.tag.transform
    transformed = transform.forward(rv, expected_untransformed)
    return transformed.eval()


class TestInitvalAssignment:
def test_dist_warnings_and_errors(self):
rv = UniformWithInitval.dist(1, 2)
assert not hasattr(rv.tag, "test_value")

with pytest.warns(DeprecationWarning, match="argument is deprecated and has no effect"):
rv = pm.Exponential.dist(lam=1, testval=0.5)
assert not hasattr(rv.tag, "test_value")
Expand All @@ -38,8 +100,66 @@ def test_new_warnings(self):
assert not hasattr(rv.tag, "test_value")
pass

def test_new_initval_behaviors(self):
    """
    Initial values are tracked by the Model, not by test values on the RV.

    Test values only appear on an RV when the user or the RV Op sets them;
    the Model always determines and stores an initial value.
    """
    with pm.Model() as pmodel:

        def stored_initval(var):
            # Initial value the model recorded for the RV's value variable.
            return pmodel.initial_values[var.tag.value_var]

        rv1 = NormalWithoutInitval("default to random draw", 1, 2)
        rv2 = NormalWithoutInitval("default to random draw the second", 1, 2)
        assert stored_initval(rv1) != 1
        assert stored_initval(rv2) != 1
        assert stored_initval(rv1) != stored_initval(rv2)
        # Randomly drawn initvals are not attached to the rv:
        for drawn in (rv1, rv2):
            assert not hasattr(drawn.tag, "test_value")

        rv = NormalWithoutInitval("user provided", 1, 2, initval=-0.2)
        user_value = np.array(-0.2, dtype=aesara.config.floatX)
        assert stored_initval(rv) == user_value
        assert not hasattr(rv.tag, "test_value")

        rv = UniformWithInitval("RVOp default", 1.5, 2)
        assert stored_initval(rv) == transform_fwd(rv, 1.75)
        assert not hasattr(rv.tag, "test_value")

        rv = UniformWithInitval("user can override RVOp default", 1.5, 2, initval=1.8)
        assert stored_initval(rv) == transform_fwd(rv, 1.8)
        assert not hasattr(rv.tag, "test_value")

        rv = UniformWithInitval("user can revert to random draw", 1.5, 2, initval=None)
        assert stored_initval(rv) != transform_fwd(rv, 1.75)
        assert not hasattr(rv.tag, "test_value")

def test_symbolic_initval(self):
    """A regression test for https://github.com/pymc-devs/pymc3/issues/4911"""
    with pm.Model() as pmodel:
        parent = pm.Normal("a")
        child = pm.Normal("b", parent, initval=parent)
        # A symbolic initval has to be evaluated into a numeric array.
        initial_b = pmodel.initial_point["b"]
        assert isinstance(initial_b, np.ndarray)


class TestSpecialDistributions:
def test_flat(self):
    """``Flat.pick_initval`` keeps user values, defaults to zeros and rejects ``None``."""
    # Without `assert`, these comparisons would be evaluated and discarded.
    assert pm.Flat.pick_initval(initval=4) == 4
    # An array `==` is element-wise, so compare with the NumPy testing helper.
    np.testing.assert_array_equal(
        pm.Flat.pick_initval(size=(2,), initval=UNSET),
        np.array([0, 0]),
    )
    with pytest.raises(NotImplementedError, match="does not support random initval"):
        pm.Flat.pick_initval(initval=None)

def test_halfflat(self):
    """``HalfFlat.pick_initval`` keeps user values, defaults to ones and rejects ``None``."""
    # Without `assert`, these comparisons would be evaluated and discarded.
    assert pm.HalfFlat.pick_initval(initval=4) == 4
    # An array `==` is element-wise, so compare with the NumPy testing helper.
    np.testing.assert_array_equal(
        pm.HalfFlat.pick_initval(size=(2,), initval=UNSET),
        np.array([1, 1]),
    )
    with pytest.raises(NotImplementedError, match="does not support random initval"):
        pm.HalfFlat.pick_initval(initval=None)

def test_automatically_assigned_test_values(self):
# ...because they don't have random number generators.
rv = pm.Flat.dist()
Expand Down
33 changes: 32 additions & 1 deletion pymc3/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
import functools
import warnings

from typing import Dict, List, Tuple, Union
from typing import Any, Dict, List, Tuple, Union

import arviz
import cloudpickle
import numpy as np
import xarray

from aesara.graph.basic import Variable
from cachetools import LRUCache, cachedmethod


Expand All @@ -38,6 +39,36 @@ def __repr__(self):
UNSET = _UnsetType()


def select_initval(
    candidate: Any,
    default: Any,
) -> Union[int, float, None, np.ndarray, _UnsetType, Variable]:
    """Pick a compatible initial value: numeric, UNSET, a Variable, or None.

    Parameters
    ----------
    candidate
        A potential initval value.
        If this is incompatible or UNSET, the default will be considered.
        Typical values are UNSET, None or instances of ndarray, int, float or Variable.
    default
        A fallback initval value.
        If both this and the candidate are incompatible, `None` is returned instead.
        Typical values are UNSET, None or instances of ndarray, int, float or Variable.

    Returns
    -------
    The candidate if it is a compatible non-UNSET value; otherwise the default
    if that is compatible or UNSET; otherwise UNSET if the candidate was UNSET;
    otherwise `None`.
    """
    # `np.number` covers NumPy scalars such as `np.float32` or `np.int64`, which
    # do not subclass the builtin `int`/`float` and would otherwise be rejected.
    valid_types = (int, float, np.number, np.ndarray, list, tuple, type(None), Variable)

    if isinstance(candidate, valid_types):
        # A concrete, compatible candidate always takes precedence.
        return candidate
    if isinstance(default, valid_types) or default is UNSET:
        # The candidate is UNSET or incompatible, but a compatible default is available.
        return default
    if candidate is UNSET:
        # The default cannot be used; keep the candidate's UNSET marker.
        return candidate
    # Neither candidate nor default can be used.
    # With initval=None the Model will draw an initval randomly.
    return None


def withparent(meth):
"""Helper wrapper that passes calls to parent's instance"""

Expand Down