Skip to content

Commit cf2fff7

Browse files
committed
Implement observe and do model transformations
1 parent b5b41eb commit cf2fff7

File tree

6 files changed

+257
-0
lines changed

6 files changed

+257
-0
lines changed

docs/api_reference.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ Distributions
3535
histogram_approximation
3636

3737

38+
Model Transformations
39+
=====================
40+
41+
.. currentmodule:: pymc_experimental.model_transform
42+
.. autosummary::
43+
:toctree: generated/
44+
45+
conditioning.do
46+
conditioning.observe
47+
48+
3849
Utils
3950
=====
4051

pymc_experimental/model_transform/__init__.py

Whitespace-only changes.
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from typing import Any, Dict, List, Sequence, Union
2+
3+
from pymc import Model
4+
from pymc.pytensorf import _replace_vars_in_graphs
5+
from pytensor.tensor import TensorVariable
6+
7+
from pymc_experimental.utils.model_fgraph import (
8+
ModelFreeRV,
9+
extract_dims,
10+
fgraph_from_model,
11+
model_from_fgraph,
12+
model_named,
13+
model_observed_rv,
14+
toposort_replace,
15+
)
16+
17+
18+
def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable], Any]) -> Model:
    """Convert free RVs to observed RVs.

    Parameters
    ----------
    model: PyMC Model
    vars_to_observations: Dict of variable or name to TensorLike
        Dictionary that maps model variables (or names) to observed values.
        Observed values must have a shape and data type that is compatible
        with the original model variable.

    Returns
    -------
    new_model: PyMC model
        A distinct PyMC model with the relevant variables observed.
        All remaining variables are cloned and can be retrieved via `new_model["var_name"]`.

    Raises
    ------
    ValueError
        If any of the given variables is not one of `model.free_RVs`.
        The message lists the offending variables.

    Examples
    --------

    .. code-block:: python

        import pymc as pm
        from pymc_experimental.model_transform.conditioning import observe

        with pm.Model() as m:
            x = pm.Normal("x")
            y = pm.Normal("y", x)
            z = pm.Normal("z", y)

        m_new = observe(m, {y: 0.5})

    """
    # Names are looked up in the model so that every key is a model variable
    resolved = {
        (model[var] if isinstance(var, str) else var): obs
        for var, obs in vars_to_observations.items()
    }

    # Note: Since PyMC can infer logprob expressions we could also allow observing Deterministics
    free_rvs = model.free_RVs  # evaluate the property once, not once per variable
    invalid = [var for var in resolved if var not in free_rvs]
    if invalid:
        invalid_names = ", ".join(str(var) for var in invalid)
        raise ValueError(
            f"At least one var is not a free variable in the model: {invalid_names}"
        )

    fgraph, memo = fgraph_from_model(model)

    replacements = {}
    for var, obs in resolved.items():
        model_free_rv = memo[var]

        # Just a sanity check
        assert isinstance(model_free_rv.owner.op, ModelFreeRV)
        assert model_free_rv in fgraph.variables

        # The second input (the value variable) is dropped: the observation takes its place
        rv, _value_var, *dims = model_free_rv.owner.inputs
        replacements[model_free_rv] = model_observed_rv(
            rv, rv.type.filter_variable(obs), *dims
        )

    toposort_replace(fgraph, tuple(replacements.items()))

    return model_from_fgraph(fgraph)
77+
78+
79+
def replace_vars_in_graphs(graphs: Sequence[TensorVariable], replacements) -> List[TensorVariable]:
    """Apply ``replacements`` to ``graphs`` through PyMC's ``_replace_vars_in_graphs``.

    Parameters
    ----------
    graphs: sequence of TensorVariable
        Expressions in which variables should be substituted.
    replacements: mapping of variable to variable
        Lookup table consulted for every visited variable and for the
        root inputs (variables without an owner) of every visited node.

    Returns
    -------
    replaced_graphs: list of TensorVariable
        The first element of the pair returned by ``_replace_vars_in_graphs``.
    """

    def replacement_fn(var, inner_replacements):
        # Root inputs will never be passed to the replacement_fn themselves,
        # so they are checked here through the node that consumes them.
        candidates = [var]
        candidates.extend(inp for inp in var.owner.inputs if inp.owner is None)

        for candidate in candidates:
            if candidate in replacements:
                inner_replacements[candidate] = replacements[candidate]

        return [var]

    result = _replace_vars_in_graphs(graphs=graphs, replacement_fn=replacement_fn)
    replaced_graphs = result[0]
    return replaced_graphs
93+
94+
95+
def do(model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], Any]) -> Model:
    """Replace model variables by intervention variables.

    Parameters
    ----------
    model: PyMC Model
    vars_to_interventions: Dict of variable or name to TensorLike
        Dictionary that maps model variables (or names) to intervention expressions.
        Intervention expressions must have a shape and data type that is compatible
        with the original model variable.

    Returns
    -------
    new_model: PyMC model
        A distinct PyMC model with the relevant variables replaced by the intervention expressions.
        All remaining variables are cloned and can be retrieved via `new_model["var_name"]`.

    Raises
    ------
    ValueError
        If any of the given variables is neither one of `model.basic_RVs` nor one of
        `model.deterministics`. The message lists the offending variables.

    Examples
    --------

    .. code-block:: python

        import arviz as az
        import pymc as pm
        from pymc_experimental.model_transform.conditioning import do

        with pm.Model() as m:
            x = pm.Normal("x", 0, 1)
            y = pm.Normal("y", x, 1)
            z = pm.Normal("z", y + x, 1)

        # Dummy posterior, same as calling `pm.sample`
        idata_m = az.from_dict({rv.name: [pm.draw(rv, draws=500)] for rv in [x, y, z]})

        # Replace `y` by a constant `100.0`
        m_do = do(m, {y: 100.0})
        with m_do:
            idata_do = pm.sample_posterior_predictive(idata_m, var_names="z")

    """
    # Names are looked up in the model so that every key is a model variable
    resolved = {
        (model[var] if isinstance(var, str) else var): intervention
        for var, intervention in vars_to_interventions.items()
    }

    # Validate before converting, so that a variable foreign to the model is reported
    # as such instead of surfacing an unrelated conversion error first.
    valid_vars = model.basic_RVs + model.deterministics
    invalid = [var for var in resolved if var not in valid_vars]
    if invalid:
        invalid_names = ", ".join(str(var) for var in invalid)
        raise ValueError(
            f"At least one var is not a variable or deterministic in the model: {invalid_names}"
        )

    do_mapping = {
        var: var.type.filter_variable(intervention) for var, intervention in resolved.items()
    }

    fgraph, memo = fgraph_from_model(model)

    # We need the interventions defined in terms of the IR fgraph representation,
    # In case they reference other variables in the model
    ir_interventions = replace_vars_in_graphs(list(do_mapping.values()), replacements=memo)

    replacements = {}
    for var, intervention in zip(do_mapping, ir_interventions):
        model_var = memo[var]

        # Just a sanity check
        assert model_var in fgraph.variables

        # The intervention takes over the name (and dims) of the variable it replaces
        intervention.name = model_var.name
        dims = extract_dims(model_var)
        replacements[model_var] = model_named(intervention, *dims)

    # Replace variables by interventions
    toposort_replace(fgraph, tuple(replacements.items()))

    return model_from_fgraph(fgraph)

pymc_experimental/tests/model_transform/__init__.py

Whitespace-only changes.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import numpy as np
2+
import pymc as pm
3+
4+
from pymc_experimental.model_transform.conditioning import do, observe
5+
6+
7+
def test_observe():
    """Check that `observe` moves free RVs to observed RVs and keeps the model logp unchanged."""
    with pm.Model() as m_old:
        x = pm.Normal("x")
        y = pm.Normal("y", x)
        z = pm.Normal("z", y)

    # Single substitution: only `y` becomes observed
    m_new = observe(m_old, {y: 0.5})

    free_rvs = m_new.free_RVs
    observed_rvs = m_new.observed_RVs
    assert len(free_rvs) == 2
    assert len(observed_rvs) == 1
    assert m_new["x"] in free_rvs
    assert m_new["y"] in observed_rvs
    assert m_new["z"] in free_rvs

    # The observed value no longer has to be supplied as part of the point
    old_logp = m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4})
    new_logp = m_new.compile_logp()({"x": 0.9, "z": 1.4})
    np.testing.assert_allclose(old_logp, new_logp)

    # Test two substitutions
    m_new = observe(m_old, {y: 0.5, z: 1.4})

    free_rvs = m_new.free_RVs
    observed_rvs = m_new.observed_RVs
    assert len(free_rvs) == 1
    assert len(observed_rvs) == 2
    assert m_new["x"] in free_rvs
    assert m_new["y"] in observed_rvs
    assert m_new["z"] in observed_rvs

    old_logp = m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4})
    new_logp = m_new.compile_logp()({"x": 0.9})
    np.testing.assert_allclose(old_logp, new_logp)
39+
40+
41+
def test_do():
    """Check that `do` replaces variables by interventions and that draws of `z` follow them."""
    with pm.Model() as m_old:
        x = pm.Normal("x", 0, 1e-3)
        y = pm.Normal("y", x, 1e-3)
        z = pm.Normal("z", y + x, 1e-3)

    # Baseline: without an intervention `z` is drawn close to zero
    baseline_draw = pm.draw(z)
    assert -5 < baseline_draw < 5

    # Single substitution, expressed in terms of another model variable
    m_new = do(m_old, {y: x + 100})

    free_rvs = m_new.free_RVs
    assert len(free_rvs) == 2
    assert m_new["x"] in free_rvs
    assert m_new["y"] in m_new.named_vars.values()
    assert m_new["z"] in free_rvs

    shifted_draw = pm.draw(m_new["z"])
    assert 95 < shifted_draw < 105

    # Test two substitutions
    with m_old:
        switch = pm.MutableData("switch", 1)
    m_new = do(m_old, {y: 100 * switch, x: 100 * switch})

    free_rvs = m_new.free_RVs
    assert len(free_rvs) == 1
    assert m_new["x"] in m_new.named_vars.values()
    assert m_new["y"] in m_new.named_vars.values()
    assert m_new["z"] in free_rvs

    switched_on_draw = pm.draw(m_new["z"])
    assert 195 < switched_on_draw < 205

    # Setting the data to zero turns both interventions into zeros
    with m_new:
        pm.set_data({"switch": 0})
    switched_off_draw = pm.draw(m_new["z"])
    assert -5 < switched_off_draw < 5

pymc_experimental/utils/model_fgraph.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,13 @@ def clone_model(model: Model) -> Tuple[Model]:
326326
327327
"""
328328
return model_from_fgraph(fgraph_from_model(model)[0])
329+
330+
331+
def extract_dims(var) -> Tuple:
    """Return the dims inputs of a model variable of the fgraph representation.

    Parameters
    ----------
    var
        A variable, typically taken from the fgraph of a model.

    Returns
    -------
    dims: tuple
        The inputs of the node that produced `var`, starting after the second
        input when that node's Op is a `ModelValuedVar` and after the first
        input for any other `ModelVar`. Empty when `var` has no owner or its
        owner's Op is not a `ModelVar`.
    """
    # The Op lives on the node that produced the variable, not on the variable itself
    node = getattr(var, "owner", None)
    if node is None or not isinstance(node.op, ModelVar):
        return ()

    # Number of leading non-dims inputs: two for valued variables, one otherwise
    n_leading = 2 if isinstance(node.op, ModelValuedVar) else 1
    return tuple(node.inputs[n_leading:])

0 commit comments

Comments
 (0)