From 3a1cdaddf776e9e23dd6ac8edd894af2e2543d61 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Thu, 21 Mar 2024 13:40:49 -0400 Subject: [PATCH 1/9] Draft of var_names arg for sample --- pymc/backends/__init__.py | 6 +++++- pymc/sampling/mcmc.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index b274af10b6..b63f68acc6 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -67,6 +67,7 @@ import numpy as np +from pytensor.tensor.variable import TensorVariable from typing_extensions import TypeAlias from pymc.backends.arviz import predictions_to_inference_data, to_inference_data @@ -99,11 +100,12 @@ def _init_trace( stats_dtypes: list[dict[str, type]], trace: Optional[BaseTrace], model: Model, + trace_vars: Optional[list[TensorVariable]] = None, ) -> BaseTrace: """Initializes a trace backend for a chain.""" strace: BaseTrace if trace is None: - strace = NDArray(model=model) + strace = NDArray(model=model, vars=trace_vars) elif isinstance(trace, BaseTrace): if len(trace) > 0: raise ValueError("Continuation of traces is no longer supported.") @@ -123,6 +125,7 @@ def init_traces( step: Union[BlockedStep, CompoundStep], initial_point: Mapping[str, np.ndarray], model: Model, + trace_vars: Optional[list[TensorVariable]] = None, ) -> tuple[Optional[RunType], Sequence[IBaseTrace]]: """Initializes a trace recorder for each chain.""" if HAS_MCB and isinstance(backend, Backend): @@ -142,6 +145,7 @@ def init_traces( chain_number=chain_number, trace=backend, model=model, + trace_vars=trace_vars, ) for chain_number in range(chains) ] diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 313a6ab8c0..4aaf9c8942 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -371,6 +371,7 @@ def sample( random_seed: RandomState = None, progressbar: bool = True, step=None, + var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, init: str = "auto", @@ -399,6 +400,7 @@ def sample( random_seed: RandomState = None, progressbar: bool = True, step=None, + var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, init: str = "auto", @@ -427,6 +429,7 @@ def sample( random_seed: RandomState = None, progressbar: bool = True, step=None, + var_names: Optional[Sequence[str]] = None, nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, init: str = "auto", @@ -478,6 +481,8 @@ def sample( A step function or collection of functions. If there are variables without step methods, step methods for those variables will be assigned automatically. By default the NUTS step method will be used, if appropriate to the model. + var_names : list of str + Names of variables to be monitored. If None, all named variables are selected automatically. nuts_sampler : str Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"]. This requires the chosen sampler to be installed. @@ -722,12 +727,18 @@ def sample( model.check_start_vals(ip) _check_start_shape(model, ip) + if var_names is not None: + trace_vars = [v for v in model.unobserved_RVs if v.name in var_names] + else: + trace_vars = None + # Create trace backends for each chain run, traces = init_traces( backend=trace, chains=chains, expected_length=draws + tune, step=step, + trace_vars=trace_vars, initial_point=ip, model=model, ) @@ -739,6 +750,7 @@ def sample( "traces": traces, "chains": chains, "tune": tune, + "var_names": var_names, "progressbar": progressbar, "model": model, "cores": cores, From 911263d7b4cd419e1c828649c5af73369383ab95 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Thu, 21 Mar 2024 13:45:10 -0400 Subject: [PATCH 2/9] Raise exception if any var_names not found --- pymc/sampling/mcmc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 4aaf9c8942..3e5587743d 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -729,6 +729,7 @@ def sample( if var_names is not None: trace_vars = [v for v in model.unobserved_RVs if v.name in var_names] + assert len(trace_vars) == len(var_names), "Not all var_names were found in the model" else: trace_vars = None From 27338863e7810878ebfdcb6ce9b0a94437a3d9be Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Thu, 21 Mar 2024 13:51:36 -0400 Subject: [PATCH 3/9] Simple test --- tests/sampling/test_mcmc.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 0fc03dd631..a18430818d 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -694,6 +694,15 @@ def test_no_init_nuts_compound(caplog): assert "Initializing NUTS" not in caplog.text +def test_sample_var_names(): + with pm.Model() as model: + a = pm.Normal("a") + b = pm.Deterministic("b", a**2) + idata = pm.sample(10, tune=10, var_names=["a"]) + assert "a" in idata.posterior + assert "b" not in idata.posterior + + class TestAssignStepMethods: def test_bernoulli(self): """Test bernoulli distribution is assigned binary gibbs metropolis method""" From 2092ea1a77afdc2dccbf6cd636a54a3991f04492 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Thu, 21 Mar 2024 21:42:14 -0400 Subject: [PATCH 4/9] Pass var_names to jax sampler --- pymc/sampling/mcmc.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 3e5587743d..c7b128e268 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -264,6 +264,7 @@ def _sample_external_nuts( random_seed: Union[RandomState, None], initvals: Union[StartDict, Sequence[Optional[StartDict]], None], model: Model, + var_names: Optional[Sequence[str]], progressbar: bool, idata_kwargs: Optional[dict], nuts_sampler_kwargs: Optional[dict], @@ -348,6 +349,7 @@ def _sample_external_nuts( random_seed=random_seed, initvals=initvals, model=model, + var_names=var_names, progressbar=progressbar, nuts_sampler=sampler, idata_kwargs=idata_kwargs, @@ -685,6 +687,7 @@ def sample( random_seed=random_seed, initvals=initvals, model=model, + var_names=var_names, progressbar=progressbar, idata_kwargs=idata_kwargs, nuts_sampler_kwargs=nuts_sampler_kwargs, From 55c06cb26d5d512b34f9fc663359223e88d954c4 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Fri, 22 Mar 2024 09:17:19 -0400 Subject: [PATCH 5/9] Update pymc/sampling/mcmc.py Better docstring Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/sampling/mcmc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index c7b128e268..d2c9f051c7 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -483,8 +483,8 @@ def sample( A step function or collection of functions. If there are variables without step methods, step methods for those variables will be assigned automatically. By default the NUTS step method will be used, if appropriate to the model. - var_names : list of str - Names of variables to be monitored. If None, all named variables are selected automatically. + var_names : list of str, optional + Names of variables to be stored in the trace. Defaults to all free variables and deterministics. nuts_sampler : str Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"]. This requires the chosen sampler to be installed. From 16585a934d294f4f7952c210ed6dc0bde5c69bc1 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Fri, 22 Mar 2024 12:17:31 -0400 Subject: [PATCH 6/9] Fixed jax variable selection --- pymc/sampling/jax.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 01f8b0d502..03627ea890 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -532,15 +532,17 @@ def sample_jax_nuts( model = modelcontext(model) - if var_names is None: - var_names = model.unobserved_value_vars + if var_names is not None: + vars_ = [v for v in model.unobserved_value_vars if v.name in var_names] + else: + vars_ = model.unobserved_value_vars if nuts_kwargs is None: nuts_kwargs = {} else: nuts_kwargs = nuts_kwargs.copy() - vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed)) + vars_to_sample = list(get_default_varnames(vars_, include_transformed=keep_untransformed)) (random_seed,) = _get_seeds_per_chain(random_seed, 1) From c9f322440317b93ba9b2620ddfb7c4d8a66ba880 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Fri, 22 Mar 2024 12:20:26 -0400 Subject: [PATCH 7/9] Added jax test --- tests/sampling/test_jax.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index d8d0cae246..77121b37ce 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -491,6 +491,15 @@ def test_sample_partially_observed(): assert idata.posterior["x"].shape == (1, 10, 3) +def test_sample_var_names(): + with pm.Model() as model: + a = pm.Normal("a") + b = pm.Deterministic("b", a**2) + idata = pm.sample(10, tune=10, nuts_sampler="numpyro", var_names=["a"]) + assert "a" in idata.posterior + assert "b" not in idata.posterior + + @pytest.mark.parametrize("nuts_sampler", ("numpyro", "blackjax")) def test_convergence_warnings(caplog, nuts_sampler): with pm.Model() as m: From efbeef954cdca4d61b5e1db2ca82cd6e45ff06f0 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Mon, 25 Mar 2024 11:32:33 -0400 Subject: [PATCH 8/9] Rename filtered varnames --- pymc/sampling/jax.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 03627ea890..f048cc2938 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -533,16 +533,18 @@ def sample_jax_nuts( model = modelcontext(model) if var_names is not None: - vars_ = [v for v in model.unobserved_value_vars if v.name in var_names] + filtered_var_names = [v for v in model.unobserved_value_vars if v.name in var_names] else: - vars_ = model.unobserved_value_vars + filtered_var_names = model.unobserved_value_vars if nuts_kwargs is None: nuts_kwargs = {} else: nuts_kwargs = nuts_kwargs.copy() - vars_to_sample = list(get_default_varnames(vars_, include_transformed=keep_untransformed)) + vars_to_sample = list( + get_default_varnames(filtered_var_names, include_transformed=keep_untransformed) + ) (random_seed,) = _get_seeds_per_chain(random_seed, 1) From 13a5d31510d39dafe8165a5c6159378f7171d0bf Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Mon, 25 Mar 2024 13:09:47 -0400 Subject: [PATCH 9/9] Add warning for var_names with nutpie --- pymc/sampling/mcmc.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index d2c9f051c7..dd97f78c88 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -293,6 +293,11 @@ def _sample_external_nuts( "`idata_kwargs` are currently ignored by the nutpie sampler", UserWarning, ) + if var_names is not None: + warnings.warn( + "`var_names` are currently ignored by the nutpie sampler", + UserWarning, + ) compiled_model = nutpie.compile_pymc_model(model) t_start = time.time() idata = nutpie.sample(