
Commit 888fe7b

Add helper to compute log_likelihood and stop computing it by default
Parent: 5692fc0 · Commit: 888fe7b
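Usage note (not part of the diff): after this change, sampling no longer attaches a log_likelihood group by default; it must be requested explicitly or computed afterwards with the new helper. A minimal sketch with an illustrative toy model:

```python
import numpy as np
import pymc as pm
from pymc.stats import compute_log_likelihood

rng = np.random.default_rng(42)
y_obs = rng.normal(loc=1.0, scale=2.0, size=100)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    pm.Normal("y", mu=mu, sigma=2.0, observed=y_obs)
    idata = pm.sample()  # as of this commit: no log_likelihood group

# Opt back in after the fact
idata = compute_log_likelihood(idata, model=model)
assert "log_likelihood" in idata.groups()
```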

File tree: 12 files changed (+477 −193 lines)

.github/workflows/tests.yml (1 addition, 0 deletions)

```diff
@@ -63,6 +63,7 @@ jobs:
         pymc/tests/sampling/test_forward.py
         pymc/tests/sampling/test_population.py
         pymc/tests/stats/test_convergence.py
+        pymc/tests/stats/test_log_likelihood.py

     - |
         pymc/tests/tuning/test_scaling.py
```

docs/source/api/misc.rst (1 addition, 0 deletions)

```diff
@@ -6,6 +6,7 @@ Other utils
 .. autosummary::
    :toctree: generated/

+   compute_log_likelihood
    find_constrained_prior
    DictToArrayBijection
```

docs/source/learn/core_notebooks/model_comparison.ipynb (189 additions, 67 deletions)

Large diffs are not rendered by default.

pymc/backends/arviz.py (15 additions, 110 deletions)

```diff
@@ -21,7 +21,6 @@
 from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
 from pytensor.graph.basic import Constant
 from pytensor.tensor.sharedvar import SharedVariable
-from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1

 import pymc

@@ -153,7 +152,7 @@ def __init__(
         trace=None,
         prior=None,
         posterior_predictive=None,
-        log_likelihood=True,
+        log_likelihood=False,
         predictions=None,
         coords: Optional[CoordSpec] = None,
         dims: Optional[DimSpec] = None,
@@ -246,68 +245,6 @@ def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrac
         trace_posterior = self.trace[self.ntune :]
         return trace_posterior, trace_warmup

-    def log_likelihood_vals_point(self, point, var, log_like_fun):
-        """Compute log likelihood for each observed point."""
-        # TODO: This is a cheap hack; we should filter-out the correct
-        # variables some other way
-        point = {i.name: point[i.name] for i in log_like_fun.f.maker.inputs if i.name in point}
-        log_like_val = np.atleast_1d(log_like_fun(point))
-
-        if isinstance(var.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)):
-            try:
-                obs_data = extract_obs_data(self.model.rvs_to_values[var])
-            except TypeError:
-                warnings.warn(f"Could not extract data from symbolic observation {var}")
-
-            mask = obs_data.mask
-            if np.ndim(mask) > np.ndim(log_like_val):
-                mask = np.any(mask, axis=-1)
-            log_like_val = np.where(mask, np.nan, log_like_val)
-        return log_like_val
-
-    def _extract_log_likelihood(self, trace):
-        """Compute log likelihood of each observation."""
-        if self.trace is None:
-            return None
-        if self.model is None:
-            return None
-
-        # TODO: We no longer need one function per observed variable
-        if self.log_likelihood is True:
-            cached = [
-                (
-                    var,
-                    self.model.compile_fn(
-                        self.model.logp(var, sum=False)[0],
-                        inputs=self.model.value_vars,
-                        on_unused_input="ignore",
-                    ),
-                )
-                for var in self.model.observed_RVs
-            ]
-        else:
-            cached = [
-                (
-                    var,
-                    self.model.compile_fn(
-                        self.model.logp(var, sum=False)[0],
-                        inputs=self.model.value_vars,
-                        on_unused_input="ignore",
-                    ),
-                )
-                for var in self.model.observed_RVs
-                if var.name in self.log_likelihood
-            ]
-        log_likelihood_dict = _DefaultTrace(len(trace.chains))
-        for var, log_like_fun in cached:
-            for k, chain in enumerate(trace.chains):
-                log_like_chain = [
-                    self.log_likelihood_vals_point(point, var, log_like_fun)
-                    for point in trace.points([chain])
-                ]
-                log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)
-        return log_likelihood_dict.trace_dict
-
     @requires("trace")
     def posterior_to_xarray(self):
         """Convert the posterior to an xarray dataset."""
@@ -382,49 +319,6 @@ def sample_stats_to_xarray(self):
             ),
         )

-    @requires("trace")
-    @requires("model")
-    def log_likelihood_to_xarray(self):
-        """Extract log likelihood and log_p data from PyMC trace."""
-        if self.predictions or not self.log_likelihood:
-            return None
-        data_warmup = {}
-        data = {}
-        warn_msg = (
-            "Could not compute log_likelihood, it will be omitted. "
-            "Check your model object or set log_likelihood=False"
-        )
-        if self.posterior_trace:
-            try:
-                data = self._extract_log_likelihood(self.posterior_trace)
-            except TypeError:
-                warnings.warn(warn_msg)
-        if self.warmup_trace:
-            try:
-                data_warmup = self._extract_log_likelihood(self.warmup_trace)
-            except TypeError:
-                warnings.warn(warn_msg)
-        return (
-            dict_to_dataset(
-                data,
-                library=pymc,
-                dims=self.dims,
-                coords=self.coords,
-                skip_event_dims=True,
-            ),
-            dict_to_dataset(
-                data_warmup,
-                library=pymc,
-                dims=self.dims,
-                coords=self.coords,
-                skip_event_dims=True,
-            ),
-        )
-
-        return dict_to_dataset(
-            data, library=pymc, coords=self.coords, dims=self.dims, default_dims=self.sample_dims
-        )
-
     @requires(["posterior_predictive"])
     def posterior_predictive_to_xarray(self):
         """Convert posterior_predictive samples to xarray."""
@@ -509,7 +403,6 @@ def to_inference_data(self):
         id_dict = {
             "posterior": self.posterior_to_xarray(),
             "sample_stats": self.sample_stats_to_xarray(),
-            "log_likelihood": self.log_likelihood_to_xarray(),
             "posterior_predictive": self.posterior_predictive_to_xarray(),
             "predictions": self.predictions_to_xarray(),
             **self.priors_to_xarray(),
@@ -519,15 +412,27 @@ def to_inference_data(self):
             id_dict["predictions_constant_data"] = self.constant_data_to_xarray()
         else:
             id_dict["constant_data"] = self.constant_data_to_xarray()
-        return InferenceData(save_warmup=self.save_warmup, **id_dict)
+        idata = InferenceData(save_warmup=self.save_warmup, **id_dict)
+        if self.log_likelihood:
+            from pymc.stats.log_likelihood import compute_log_likelihood
+
+            idata = compute_log_likelihood(
+                idata,
+                var_names=None if self.log_likelihood is True else self.log_likelihood,
+                extend_inferencedata=True,
+                model=self.model,
+                sample_dims=self.sample_dims,
+                progressbar=False,
+            )
+        return idata


 def to_inference_data(
     trace: Optional["MultiTrace"] = None,
     *,
     prior: Optional[Mapping[str, Any]] = None,
     posterior_predictive: Optional[Mapping[str, Any]] = None,
-    log_likelihood: Union[bool, Iterable[str]] = True,
+    log_likelihood: Union[bool, Iterable[str]] = False,
     coords: Optional[CoordSpec] = None,
     dims: Optional[DimSpec] = None,
     sample_dims: Optional[List] = None,
```
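The net effect in `to_inference_data`: the `InferenceData` is built first and, only if `log_likelihood` is truthy, extended via the new helper. A hedged sketch of the caller-facing change (`model` and the observed variable name `"y"` are illustrative):

```python
import pymc as pm

with model:
    trace = pm.sample(return_inferencedata=False)
    # New default: the resulting InferenceData has no log_likelihood group
    idata = pm.to_inference_data(trace)
    # Opt back in, either for all observed RVs or for a subset
    idata_all = pm.to_inference_data(trace, log_likelihood=True)
    idata_y = pm.to_inference_data(trace, log_likelihood=["y"])
```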

pymc/sampling/jax.py (2 additions, 2 deletions)

```diff
@@ -412,7 +412,7 @@ def sample_blackjax_nuts(
     else:
         idata_kwargs = idata_kwargs.copy()

-    if idata_kwargs.pop("log_likelihood", bool(model.observed_RVs)):
+    if idata_kwargs.pop("log_likelihood", False):
         tic5 = datetime.now()
         print("Computing Log Likelihood...", file=sys.stdout)
         log_likelihood = _get_log_likelihood(
@@ -634,7 +634,7 @@ def sample_numpyro_nuts(
     else:
         idata_kwargs = idata_kwargs.copy()

-    if idata_kwargs.pop("log_likelihood", bool(model.observed_RVs)):
+    if idata_kwargs.pop("log_likelihood", False):
         tic5 = datetime.now()
         print("Computing Log Likelihood...", file=sys.stdout)
         log_likelihood = _get_log_likelihood(
```
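For the JAX-based samplers the default flips from `bool(model.observed_RVs)` to `False`, so even models with observed RVs now skip the computation unless asked. Opting in goes through `idata_kwargs` (sketch; `model` is assumed to be a model with observed data):

```python
from pymc.sampling.jax import sample_numpyro_nuts

with model:
    idata = sample_numpyro_nuts(idata_kwargs={"log_likelihood": True})
```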

pymc/stats/__init__.py (2 additions, 1 deletion)

```diff
@@ -27,5 +27,6 @@
     if not attr.startswith("__"):
         setattr(sys.modules[__name__], attr, obj)

+from pymc.stats.log_likelihood import compute_log_likelihood

-__all__ = tuple(az.stats.__all__)
+__all__ = ("compute_log_likelihood",) + tuple(az.stats.__all__)
```

pymc/stats/log_likelihood.py (new file, 130 additions)

```python
# Copyright 2022 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Sequence

import numpy as np

from arviz import InferenceData, dict_to_dataset
from fastprogress import progress_bar

import pymc

from pymc.backends.arviz import _DefaultTrace
from pymc.model import Model, modelcontext
from pymc.util import dataset_to_point_list

__all__ = ("compute_log_likelihood",)


def compute_log_likelihood(
    idata: InferenceData,
    *,
    var_names: Optional[Sequence[str]] = None,
    extend_inferencedata: bool = True,
    model: Optional[Model] = None,
    sample_dims: Sequence[str] = ("chain", "draw"),
    progressbar=True,
):
    """Compute elemwise log_likelihood of model given InferenceData with posterior group

    Parameters
    ----------
    idata : InferenceData
        InferenceData with posterior group
    var_names : sequence of str, optional
        List of Observed variable names for which to compute log_likelihood. Defaults to all observed variables
    extend_inferencedata : bool, default True
        Whether to extend the original InferenceData or return a new one
    model : Model, optional
    sample_dims : sequence of str, default ("chain", "draw")
    progressbar : bool, default True

    Returns
    -------
    idata : InferenceData
        InferenceData with log_likelihood group
    """

    posterior = idata["posterior"]

    model = modelcontext(model)

    if var_names is None:
        observed_vars = model.observed_RVs
        var_names = tuple(rv.name for rv in observed_vars)
    else:
        observed_vars = [model.named_vars[name] for name in var_names]
        if not set(observed_vars).issubset(model.observed_RVs):
            raise ValueError(f"var_names must refer to observed_RVs in the model. Got: {var_names}")

    # We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values
    # pylint: disable=used-before-assignment
    try:
        original_rvs_to_values = model.rvs_to_values
        original_rvs_to_transforms = model.rvs_to_transforms

        model.rvs_to_values = {
            rv: rv.clone() if rv not in model.observed_RVs else value
            for rv, value in model.rvs_to_values.items()
        }
        model.rvs_to_transforms = {rv: None for rv in model.basic_RVs}

        elemwise_loglike_fn = model.compile_fn(
            inputs=model.value_vars,
            outs=model.logp(vars=observed_vars, sum=False),
            on_unused_input="ignore",
        )
    finally:
        model.rvs_to_values = original_rvs_to_values
        model.rvs_to_transforms = original_rvs_to_transforms
    # pylint: enable=used-before-assignment

    # Ignore Deterministics
    posterior_values = posterior[[rv.name for rv in model.free_RVs]]
    posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
    n_pts = len(posterior_pts)
    loglike_dict = _DefaultTrace(n_pts)
    indices = range(n_pts)
    if progressbar:
        indices = progress_bar(indices, total=n_pts, display=progressbar)

    for idx in indices:
        loglikes_pts = elemwise_loglike_fn(posterior_pts[idx])
        for rv_name, rv_loglike in zip(var_names, loglikes_pts):
            loglike_dict.insert(rv_name, rv_loglike, idx)

    loglike_trace = loglike_dict.trace_dict
    for key, array in loglike_trace.items():
        loglike_trace[key] = array.reshape(
            (*[len(coord) for coord in stacked_dims.values()], *array.shape[1:])
        )

    loglike_dataset = dict_to_dataset(
        loglike_trace,
        library=pymc,
        dims={dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()},
        coords={
            cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
            for cname, cvals in model.coords.items()
        },
        default_dims=list(sample_dims),
        skip_event_dims=True,
    )

    if extend_inferencedata:
        idata.add_groups(dict(log_likelihood=loglike_dataset))
        return idata
    else:
        return loglike_dataset
```
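End-to-end sketch of the new helper (toy model; all names are illustrative). Transforms are stripped temporarily because the posterior group stores values on the constrained scale, which can then be fed straight into the compiled elemwise logp function:

```python
import numpy as np
import pymc as pm
from pymc.stats import compute_log_likelihood

rng = np.random.default_rng(0)
y_obs = rng.normal(loc=1.0, scale=2.0, size=50)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)  # sampled on the log scale, stored constrained
    pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)
    idata = pm.sample(chains=2, draws=500)

# extend_inferencedata=False returns a standalone Dataset instead of mutating idata
loglike = compute_log_likelihood(
    idata,
    var_names=["y"],
    extend_inferencedata=False,
    model=model,
)
print(loglike["y"].shape)  # (2, 500, 50): chain, draw, observation
```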
