Skip to content

Commit abbc5b8

Browse files
committed
Implement get_value_vars_from_user_vars util to check sampler vars are valid
Deprecates old `allinmodel`
1 parent 5ec2041 commit abbc5b8

File tree

8 files changed

+88
-54
lines changed

8 files changed

+88
-54
lines changed

pymc/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
UNSET,
7070
WithMemoization,
7171
get_transformed_name,
72+
get_value_vars_from_user_vars,
7273
get_var_name,
7374
treedict,
7475
treelist,
@@ -617,6 +618,7 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
617618
if grad_vars is None:
618619
grad_vars = self.continuous_value_vars
619620
else:
621+
grad_vars = get_value_vars_from_user_vars(grad_vars, self)
620622
for i, var in enumerate(grad_vars):
621623
if var.dtype not in continuous_types:
622624
raise ValueError(f"Can only compute the gradient of continuous types: {var}")

pymc/step_methods/hmc/base_hmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from pymc.step_methods.hmc import integration
3232
from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
3333
from pymc.tuning import guess_scaling
34+
from pymc.util import get_value_vars_from_user_vars
3435

3536
logger = logging.getLogger("pymc")
3637

@@ -91,8 +92,7 @@ def __init__(
9192
if vars is None:
9293
vars = self._model.continuous_value_vars
9394
else:
94-
vars = [self._model.rvs_to_values.get(var, var) for var in vars]
95-
95+
vars = get_value_vars_from_user_vars(vars, self._model)
9696
super().__init__(vars, blocked=blocked, model=self._model, dtype=dtype, **aesara_kwargs)
9797

9898
self.adapt_step_size = adapt_step_size

pymc/step_methods/metropolis.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
"MultivariateNormalProposal",
5757
]
5858

59+
from pymc.util import get_value_vars_from_user_vars
60+
5961
# Available proposal distributions for Metropolis
6062

6163

@@ -176,9 +178,7 @@ def __init__(
176178
if vars is None:
177179
vars = model.value_vars
178180
else:
179-
vars = [model.rvs_to_values.get(var, var) for var in vars]
180-
181-
vars = pm.inputvars(vars)
181+
vars = get_value_vars_from_user_vars(vars, model)
182182

183183
initial_values_shape = [initial_values[v.name].shape for v in vars]
184184
if S is None:
@@ -394,7 +394,7 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
394394
self.steps_until_tune = tune_interval
395395
self.accepted = 0
396396

397-
vars = [model.rvs_to_values.get(var, var) for var in vars]
397+
vars = get_value_vars_from_user_vars(vars, model)
398398

399399
if not all([v.dtype in pm.discrete_types for v in vars]):
400400
raise ValueError("All variables must be Bernoulli for BinaryMetropolis")
@@ -484,8 +484,9 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
484484
# transition probabilities
485485
self.transit_p = transit_p
486486

487+
vars = get_value_vars_from_user_vars(vars, model)
488+
487489
initial_point = model.initial_point()
488-
vars = [model.rvs_to_values.get(var, var) for var in vars]
489490
self.dim = sum(initial_point[v.name].size for v in vars)
490491

491492
if order == "random":
@@ -566,8 +567,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
566567

567568
model = pm.modelcontext(model)
568569

569-
vars = [model.rvs_to_values.get(var, var) for var in vars]
570-
vars = pm.inputvars(vars)
570+
vars = get_value_vars_from_user_vars(vars, model)
571571

572572
initial_point = model.initial_point()
573573

@@ -777,8 +777,7 @@ def __init__(
777777
if vars is None:
778778
vars = model.continuous_value_vars
779779
else:
780-
vars = [model.rvs_to_values.get(var, var) for var in vars]
781-
vars = pm.inputvars(vars)
780+
vars = get_value_vars_from_user_vars(vars, model)
782781

783782
if S is None:
784783
S = np.ones(initial_values_size)
@@ -928,8 +927,7 @@ def __init__(
928927
if vars is None:
929928
vars = model.continuous_value_vars
930929
else:
931-
vars = [model.rvs_to_values.get(var, var) for var in vars]
932-
vars = pm.inputvars(vars)
930+
vars = get_value_vars_from_user_vars(vars, model)
933931

934932
if S is None:
935933
S = np.ones(initial_values_size)

pymc/step_methods/slicer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
import numpy as np
1818
import numpy.random as nr
1919

20-
from pymc.aesaraf import inputvars
2120
from pymc.blocking import RaveledVars
2221
from pymc.model import modelcontext
2322
from pymc.step_methods.arraystep import ArrayStep, Competence
23+
from pymc.util import get_value_vars_from_user_vars
2424
from pymc.vartypes import continuous_types
2525

2626
__all__ = ["Slice"]
@@ -65,8 +65,7 @@ def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, *
6565
if vars is None:
6666
vars = self.model.continuous_value_vars
6767
else:
68-
vars = [self.model.rvs_to_values.get(var, var) for var in vars]
69-
vars = inputvars(vars)
68+
vars = get_value_vars_from_user_vars(vars, self.model)
7069

7170
super().__init__(vars, [self.model.compile_logp()], **kwargs)
7271

pymc/tests/test_util.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
_get_seeds_per_chain,
2929
dataset_to_point_list,
3030
drop_warning_stat,
31+
get_value_vars_from_user_vars,
3132
hash_key,
3233
hashable,
3334
locally_cachedmethod,
@@ -236,3 +237,32 @@ def test_get_seeds_per_chain():
236237

237238
with pytest.raises(ValueError, match=re.escape("The `seeds` must be array-like")):
238239
_get_seeds_per_chain({1: 1, 2: 2}, 2)
240+
241+
242+
def test_get_value_vars_from_user_vars():
    """Check conversion of user vars to value vars and the rejection of invalid vars."""
    with pm.Model() as model1:
        x1 = pm.Normal("x1", mu=0, sigma=1)
        y1 = pm.Normal("y1", mu=0, sigma=1)

    x1_value = model1.rvs_to_values[x1]
    y1_value = model1.rvs_to_values[y1]

    # Random variables given in a list are mapped to their value variables, in order
    both = get_value_vars_from_user_vars([x1, y1], model1)
    assert both == [x1_value, y1_value]
    only_x = get_value_vars_from_user_vars([x1], model1)
    assert only_x == [x1_value]

    # Deliberately NOT wrapped in a list: a bare variable must be accepted, and a
    # value variable must be passed through unchanged
    bare = get_value_vars_from_user_vars(x1_value, model1)
    assert bare == [x1_value]

    with pm.Model() as model2:
        x2 = pm.Normal("x2", mu=0, sigma=1)
        y2 = pm.Normal("y2", mu=0, sigma=1)
        det2 = pm.Deterministic("det2", x2 + y2)

    prefix = "The following variables are not random variables in the model:"

    # (expected error pattern, user vars, model the vars are checked against)
    invalid_cases = [
        # Every variable belongs to a different model
        (rf"{prefix} \['x2', 'y2'\]", [x2, y2], model1),
        # Only the foreign variable is reported when mixed with a valid one
        (rf"{prefix} \['x2'\]", [x2, y1], model1),
        (rf"{prefix} \['x2'\]", [x2], model1),
        # A Deterministic is part of the model but is rejected all the same
        (rf"{prefix} \['det2'\]", [det2], model2),
    ]
    for pattern, user_vars, model in invalid_cases:
        with pytest.raises(ValueError, match=pattern):
            get_value_vars_from_user_vars(user_vars, model)

pymc/tests/tuning/test_starting.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pymc.tests.checks import close_to
2323
from pymc.tests.helpers import select_by_precision
2424
from pymc.tests.models import non_normal, simple_arbitrary_det, simple_model
25-
from pymc.tuning import find_MAP, starting
25+
from pymc.tuning import find_MAP
2626

2727

2828
@pytest.mark.parametrize("bounded", [False, True])
@@ -147,28 +147,3 @@ def test_find_MAP_issue_4488():
147147
assert not set.difference({"x_missing", "x_missing_log__", "y"}, set(map_estimate.keys()))
148148
np.testing.assert_allclose(map_estimate["x_missing"], 0.2, rtol=1e-4, atol=1e-4)
149149
np.testing.assert_allclose(map_estimate["y"], [2.0, map_estimate["x_missing"][0] + 1])
150-
151-
152-
def test_allinmodel():
153-
model1 = pm.Model()
154-
model2 = pm.Model()
155-
with model1:
156-
x1 = pm.Normal("x1", mu=0, sigma=1)
157-
y1 = pm.Normal("y1", mu=0, sigma=1)
158-
with model2:
159-
x2 = pm.Normal("x2", mu=0, sigma=1)
160-
y2 = pm.Normal("y2", mu=0, sigma=1)
161-
162-
x1 = model1.rvs_to_values[x1]
163-
y1 = model1.rvs_to_values[y1]
164-
x2 = model2.rvs_to_values[x2]
165-
y2 = model2.rvs_to_values[y2]
166-
167-
starting.allinmodel([x1, y1], model1)
168-
starting.allinmodel([x1], model1)
169-
with pytest.raises(ValueError, match=r"Some variables not in the model: \['x2', 'y2'\]"):
170-
starting.allinmodel([x2, y2], model1)
171-
with pytest.raises(ValueError, match=r"Some variables not in the model: \['x2'\]"):
172-
starting.allinmodel([x2, y1], model1)
173-
with pytest.raises(ValueError, match=r"Some variables not in the model: \['x2'\]"):
174-
starting.allinmodel([x2], model1)

pymc/tuning/starting.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,10 @@
3030

3131
import pymc as pm
3232

33-
from pymc.aesaraf import inputvars
3433
from pymc.blocking import DictToArrayBijection, RaveledVars
3534
from pymc.initial_point import make_initial_point_fn
3635
from pymc.model import modelcontext
37-
from pymc.util import get_default_varnames, get_var_name
36+
from pymc.util import get_default_varnames, get_value_vars_from_user_vars
3837
from pymc.vartypes import discrete_types, typefilter
3938

4039
__all__ = ["find_MAP"]
@@ -96,11 +95,9 @@ def find_MAP(
9695
if not vars:
9796
raise ValueError("Model has no unobserved continuous variables.")
9897
else:
99-
vars = [model.rvs_to_values.get(var, var) for var in vars]
98+
vars = get_value_vars_from_user_vars(vars, model)
10099

101-
vars = inputvars(vars)
102100
disc_vars = list(typefilter(vars, discrete_types))
103-
allinmodel(vars, model)
104101
ipfn = make_initial_point_fn(
105102
model=model,
106103
jitter_rvs=set(),
@@ -182,13 +179,6 @@ def allfinite(x):
182179
return np.all(isfinite(x))
183180

184181

185-
def allinmodel(vars, model):
186-
notin = [v for v in vars if v not in model.value_vars]
187-
if notin:
188-
notin = list(map(get_var_name, notin))
189-
raise ValueError("Some variables not in the model: " + str(notin))
190-
191-
192182
class CostFuncWrapper:
193183
def __init__(self, maxeval=5000, progressbar=True, logp_func=None, dlogp_func=None):
194184
self.n_eval = 0

pymc/util.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import numpy as np
2222
import xarray
2323

24+
from aesara import Variable
2425
from aesara.compile import SharedVariable
2526
from cachetools import LRUCache, cachedmethod
2627

@@ -441,3 +442,42 @@ def _get_unique_seeds_per_chain(integers_fn):
441442
)
442443

443444
return random_state
445+
446+
447+
def get_value_vars_from_user_vars(
    vars: Union[Variable, Sequence[Variable]], model
) -> List[Variable]:
    """Translate user-supplied ``vars`` into the model's value variables.

    Every input is looked up in ``model.rvs_to_values``. Inputs that are not keys
    of that mapping are kept as they are, so value variables can be passed
    directly (e.g. by internal callers). The result is then validated against
    ``model.value_vars``.

    Parameters
    ----------
    vars : Variable or sequence of Variable
        One variable, or a sequence of variables. Any object that is not a
        ``Sequence`` is treated as a single variable.
    model
        Object exposing the ``rvs_to_values`` mapping and the ``value_vars``
        collection used for the lookup and the validation.

    Returns
    -------
    value_vars : list of Variable
        One entry per input variable, in the same order as the input.

    Raises
    ------
    ValueError
        If any resulting variable is not found in ``model.value_vars``.
    """
    # Normalise the input so that a lone variable is handled like a 1-element list
    user_vars = vars if isinstance(vars, Sequence) else [vars]
    value_vars = [model.rvs_to_values.get(user_var, user_var) for user_var in user_vars]

    # Everything that survived the lookup must be one of the model's value variables
    known_value_vars = model.value_vars
    unknown = [candidate for candidate in value_vars if candidate not in known_value_vars]
    if not unknown:
        return value_vars

    unknown_names = [get_var_name(candidate) for candidate in unknown]
    # The message talks about random variables even though the offending input may
    # be a wrong value variable: most users are unaware of the distinction
    message = "The following variables are not random variables in the model: " + str(
        unknown_names
    )
    raise ValueError(message)

0 commit comments

Comments
 (0)