From 88351040e56bc4f1ca26b73182eabb9d2289f785 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Wed, 2 Dec 2020 19:07:01 +0100 Subject: [PATCH 1/8] - Fix regression caused by #4211 --- pymc3/sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index a9771d3e55..227d74ff16 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -414,15 +414,15 @@ def sample( """ model = modelcontext(model) if start is None: - start = model.test_point + check_start_vals(model.test_point, model) else: if isinstance(start, dict): update_start_vals(start, model.test_point, model) else: for chain_start_vals in start: update_start_vals(chain_start_vals, model.test_point, model) + check_start_vals(start, model) - check_start_vals(start, model) if cores is None: cores = min(4, _cpu_count()) From 7aa0a3612a6bf106100879cd84e0e1b4c302eeda Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 3 Dec 2020 10:51:22 +0100 Subject: [PATCH 2/8] - Add test to make sure jitter is being applied to chains starting points by default --- pymc3/sampling.py | 2 +- pymc3/tests/test_sampling.py | 23 ++++++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 227d74ff16..8fc7118e8e 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -490,8 +490,8 @@ def sample( progressbar=progressbar, **kwargs, ) - check_start_vals(start_, model) if start is None: + check_start_vals(start_, model) start = start_ except (AttributeError, NotImplementedError, tg.NullTypeGradError): # gradient computation failed diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 2185542f17..c6e5f34e86 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import nullcontext as does_not_raise from itertools import combinations from typing import Tuple import numpy as np @@ -25,7 +26,7 @@ import theano from pymc3.tests.models import simple_init from pymc3.tests.helpers import SeededTest -from pymc3.exceptions import IncorrectArgumentsError +from pymc3.exceptions import IncorrectArgumentsError, SamplingError from scipy import stats import pytest @@ -761,6 +762,26 @@ def test_exec_nuts_init(method): assert "a" in start[0] and "b_log__" in start[0] +@pytest.mark.parametrize( + "init, start, expectation", + [ + ("auto", None, pytest.raises(SamplingError)), + ("jitter+adapt_diag", None, pytest.raises(SamplingError)), + ("auto", {"x": 0}, does_not_raise()), + ("jitter+adapt_diag", {'x': 0}, does_not_raise()), + ], +) +def test_default_sample_nuts_jitter(init, start, expectation): + # Random seed was selected to make sure initialization with "jitter+adapt_diag" would fail. + # This will need to be changed in the future if initialization or randomization method changes + # or if default initialization is made more robust. + with pm.Model() as m: + x = pm.HalfNormal('x', transform=None) + with expectation: + pm.sample(tune=1, draws=0, chains=4, random_seed=7, + init=init, start=start) + + @pytest.fixture(scope="class") def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]: with pm.Model() as pmodel: From 25cbaba6a3324a4a98fe7ffa4031ccd8e61df23c Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 3 Dec 2020 11:08:41 +0100 Subject: [PATCH 3/8] - Import appropriate empty context for python < 3.7 --- pymc3/tests/test_sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index c6e5f34e86..d04c58484d 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext as does_not_raise +from contextlib import ExitStack as does_not_raise from itertools import combinations from typing import Tuple import numpy as np @@ -778,7 +778,7 @@ def test_default_sample_nuts_jitter(init, start, expectation): with pm.Model() as m: x = pm.HalfNormal('x', transform=None) with expectation: - pm.sample(tune=1, draws=0, chains=4, random_seed=7, + pm.sample(tune=1, draws=0, chains=1, random_seed=7, init=init, start=start) From e77494268160304d69851b8fe32f77c539391707 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 3 Dec 2020 11:14:27 +0100 Subject: [PATCH 4/8] - Apply black formatting --- pymc3/tests/test_sampling.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index d04c58484d..5ee8aeea1e 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -768,7 +768,7 @@ def test_exec_nuts_init(method): ("auto", None, pytest.raises(SamplingError)), ("jitter+adapt_diag", None, pytest.raises(SamplingError)), ("auto", {"x": 0}, does_not_raise()), - ("jitter+adapt_diag", {'x': 0}, does_not_raise()), + ("jitter+adapt_diag", {"x": 0}, does_not_raise()), ], ) def test_default_sample_nuts_jitter(init, start, expectation): @@ -776,10 +776,9 @@ def test_default_sample_nuts_jitter(init, start, expectation): # This will need to be changed in the future if initialization or randomization method changes # or if default initialization is made more robust. with pm.Model() as m: - x = pm.HalfNormal('x', transform=None) + x = pm.HalfNormal("x", transform=None) with expectation: - pm.sample(tune=1, draws=0, chains=1, random_seed=7, - init=init, start=start) + pm.sample(tune=1, draws=0, chains=1, random_seed=7, init=init, start=start) @pytest.fixture(scope="class") From 38d21c831544b127f3d3605174ba79860aa5ed20 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 3 Dec 2020 12:09:15 +0100 Subject: [PATCH 5/8] - Change the second check_start_vals to explicitly run on the newly assigned start variable. --- pymc3/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 8fc7118e8e..a97e8dd707 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -491,8 +491,8 @@ def sample( **kwargs, ) if start is None: - check_start_vals(start_, model) start = start_ + check_start_vals(start, model) except (AttributeError, NotImplementedError, tg.NullTypeGradError): # gradient computation failed _log.info("Initializing NUTS failed. " "Falling back to elementwise auto-assignment.") From 98187add910174365febe6617f4243a9500db131 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 3 Dec 2020 12:27:35 +0100 Subject: [PATCH 6/8] - Improve test documentation and add a new condition --- pymc3/tests/test_sampling.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 5ee8aeea1e..85b96a477e 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -769,11 +769,15 @@ def test_exec_nuts_init(method): ("jitter+adapt_diag", None, pytest.raises(SamplingError)), ("auto", {"x": 0}, does_not_raise()), ("jitter+adapt_diag", {"x": 0}, does_not_raise()), + ("adapt_diag", None, does_not_raise()), ], ) def test_default_sample_nuts_jitter(init, start, expectation): - # Random seed was selected to make sure initialization with "jitter+adapt_diag" would fail. - # This will need to be changed in the future if initialization or randomization method changes + # This test tries to check whether the starting points returned by init_nuts are actually being + # used when pm.sample() is called without specifying an explicit start point (see + # https://github.com/pymc-devs/pymc3/pull/4285). + # A random seed was selected to make sure the initialization with "jitter+adapt_diag" would fail. + # This will need to be changed in the future if the initialization or randomization method changes # or if default initialization is made more robust. with pm.Model() as m: x = pm.HalfNormal("x", transform=None) From 8699dad3c2b4ae4ca6ffc1f33dbbace104eee308 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 3 Dec 2020 14:02:54 +0100 Subject: [PATCH 7/8] Use monkeypatch for more robust test --- pymc3/tests/test_sampling.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 85b96a477e..7ae8bdc45e 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -772,17 +772,22 @@ def test_exec_nuts_init(method): ("adapt_diag", None, does_not_raise()), ], ) -def test_default_sample_nuts_jitter(init, start, expectation): - # This test tries to check whether the starting points returned by init_nuts are actually being - # used when pm.sample() is called without specifying an explicit start point (see +def test_default_sample_nuts_jitter(init, start, expectation, monkeypatch): + # This test tries to check whether the starting points returned by init_nuts are actually + # being used when pm.sample() is called without specifying an explicit start point (see # https://github.com/pymc-devs/pymc3/pull/4285). - # A random seed was selected to make sure the initialization with "jitter+adapt_diag" would fail. - # This will need to be changed in the future if the initialization or randomization method changes - # or if default initialization is made more robust. + def _mocked_init_nuts(*args, **kwargs): + if init == 'adapt_diag': + start_ = [{'x': np.array(0.79788456)}] + else: + start_ = [{'x': np.array(-0.04949886)}] + _, step = pm.init_nuts(*args, **kwargs) + return start_, step + monkeypatch.setattr('pymc3.sampling.init_nuts', _mocked_init_nuts) with pm.Model() as m: x = pm.HalfNormal("x", transform=None) with expectation: - pm.sample(tune=1, draws=0, chains=1, random_seed=7, init=init, start=start) + pm.sample(tune=1, draws=0, chains=1, init=init, start=start) @pytest.fixture(scope="class") From 6699d664651cc169355d194bf4ceee379491e764 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 3 Dec 2020 14:09:19 +0100 Subject: [PATCH 8/8] - Black formatting, once again... --- pymc3/tests/test_sampling.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 7ae8bdc45e..2ccaf80623 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -777,13 +777,14 @@ def test_default_sample_nuts_jitter(init, start, expectation, monkeypatch): # being used when pm.sample() is called without specifying an explicit start point (see # https://github.com/pymc-devs/pymc3/pull/4285). def _mocked_init_nuts(*args, **kwargs): - if init == 'adapt_diag': - start_ = [{'x': np.array(0.79788456)}] + if init == "adapt_diag": + start_ = [{"x": np.array(0.79788456)}] else: - start_ = [{'x': np.array(-0.04949886)}] + start_ = [{"x": np.array(-0.04949886)}] _, step = pm.init_nuts(*args, **kwargs) return start_, step - monkeypatch.setattr('pymc3.sampling.init_nuts', _mocked_init_nuts) + + monkeypatch.setattr("pymc3.sampling.init_nuts", _mocked_init_nuts) with pm.Model() as m: x = pm.HalfNormal("x", transform=None) with expectation: