Skip to content

Commit 059ef47

Browse files
committed
Emit warning when passing non free_RVs to find_MAP
Also improve docstrings
1 parent abbc5b8 commit 059ef47

File tree

3 files changed

+52
-18
lines changed

3 files changed

+52
-18
lines changed

pymc/aesaraf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,11 @@ def expand_replace(var):
261261

262262

263263
def rvs_to_value_vars(
264-
graphs: Iterable[TensorVariable],
264+
graphs: Iterable[Variable],
265265
apply_transforms: bool = True,
266-
initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None,
266+
initial_replacements: Optional[Dict[Variable, Variable]] = None,
267267
**kwargs,
268-
) -> List[TensorVariable]:
268+
) -> List[Variable]:
269269
"""Clone and replace random variables in graphs with their value variables.
270270
271271
This will *not* recompute test values in the resulting graphs.

pymc/tests/tuning/test_starting.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import re
15+
1416
import numpy as np
1517
import pytest
1618

@@ -147,3 +149,16 @@ def test_find_MAP_issue_4488():
147149
assert not set.difference({"x_missing", "x_missing_log__", "y"}, set(map_estimate.keys()))
148150
np.testing.assert_allclose(map_estimate["x_missing"], 0.2, rtol=1e-4, atol=1e-4)
149151
np.testing.assert_allclose(map_estimate["y"], [2.0, map_estimate["x_missing"][0] + 1])
152+
153+
154+
def test_find_MAP_warning_non_free_RVs():
155+
with pm.Model() as m:
156+
x = pm.Normal("x")
157+
y = pm.Normal("y")
158+
det = pm.Deterministic("det", x + y)
159+
pm.Normal("z", det, 1e-5, observed=100)
160+
161+
msg = "Intermediate variables (such as Deterministic or Potential) were passed"
162+
with pytest.warns(UserWarning, match=re.escape(msg)):
163+
r = pm.find_MAP(vars=[det])
164+
np.testing.assert_allclose([r["x"], r["y"], r["det"]], [50, 50, 100])

pymc/tuning/starting.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
@author: johnsalvatier
1919
"""
2020
import sys
21+
import warnings
2122

22-
from typing import Optional
23+
from typing import Optional, Sequence
2324

2425
import aesara.gradient as tg
2526
import numpy as np
2627

28+
from aesara import Variable
2729
from fastprogress.fastprogress import ProgressBar, progress_bar
2830
from numpy import isfinite
2931
from scipy.optimize import minimize
@@ -41,7 +43,7 @@
4143

4244
def find_MAP(
4345
start=None,
44-
vars=None,
46+
vars: Optional[Sequence[Variable]] = None,
4547
method="L-BFGS-B",
4648
return_raw=False,
4749
include_transformed=True,
@@ -61,20 +63,23 @@ def find_MAP(
6163
Parameters
6264
----------
6365
start: `dict` of parameter values (Defaults to `model.initial_point`)
64-
vars: list
65-
List of variables to optimize and set to optimum (Defaults to all continuous).
66-
method: string or callable
67-
Optimization algorithm (Defaults to 'L-BFGS-B' unless
68-
discrete variables are specified in `vars`, then
69-
`Powell` which will perform better). For instructions on use of a callable,
70-
refer to SciPy's documentation of `optimize.minimize`.
71-
return_raw: bool
72-
Whether to return the full output of scipy.optimize.minimize (Defaults to `False`)
66+
These values will be fixed and used for any free RandomVariables that are
67+
not being optimized.
68+
vars: list of TensorVariable
69+
List of free RandomVariables to optimize the posterior with respect to.
70+
Defaults to all continuous RVs in a model. The respective value variables
71+
may also be passed instead.
72+
method: string or callable, optional
73+
Optimization algorithm. Defaults to 'L-BFGS-B' unless discrete variables are
74+
specified in `vars`, then `Powell` which will perform better. For instructions
75+
on use of a callable, refer to SciPy's documentation of `optimize.minimize`.
76+
return_raw: bool, optional defaults to False
77+
Whether to return the full output of scipy.optimize.minimize
7378
include_transformed: bool, optional defaults to True
74-
Flag for reporting automatically transformed variables in addition
75-
to original variables.
79+
Flag for reporting automatically unconstrained transformed values in addition
80+
to the constrained values
7681
progressbar: bool, optional defaults to True
77-
Whether or not to display a progress bar in the command line.
82+
Whether to display a progress bar in the command line.
7883
maxeval: int, optional, defaults to 5000
7984
The maximum number of times the posterior distribution is evaluated.
8085
model: Model (optional if in `with` context)
@@ -95,7 +100,21 @@ def find_MAP(
95100
if not vars:
96101
raise ValueError("Model has no unobserved continuous variables.")
97102
else:
98-
vars = get_value_vars_from_user_vars(vars, model)
103+
try:
104+
vars = get_value_vars_from_user_vars(vars, model)
105+
except ValueError as exc:
106+
# Accomodate case where user passed non-pure RV nodes
107+
vars = pm.inputvars(pm.aesaraf.rvs_to_value_vars(vars))
108+
if vars:
109+
# Make sure they belong to current model again...
110+
vars = get_value_vars_from_user_vars(vars, model)
111+
warnings.warn(
112+
"Intermediate variables (such as Deterministic or Potential) were passed. "
113+
"find_MAP will optimize the underlying free_RVs instead.",
114+
UserWarning,
115+
)
116+
else:
117+
raise exc
99118

100119
disc_vars = list(typefilter(vars, discrete_types))
101120
ipfn = make_initial_point_fn(

0 commit comments

Comments
 (0)