Skip to content

Commit ff32a66

Browse files
committed
Implement observe and do model transformations
1 parent 1d928af commit ff32a66

File tree

6 files changed

+348
-0
lines changed

6 files changed

+348
-0
lines changed

docs/api_reference.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ Distributions
3535
histogram_approximation
3636

3737

38+
Model Transformations
39+
=====================
40+
41+
.. currentmodule:: pymc_experimental.model_transform
42+
.. autosummary::
43+
:toctree: generated/
44+
45+
conditioning.do
46+
conditioning.observe
47+
48+
3849
Utils
3950
=====
4051

pymc_experimental/model_transform/__init__.py

Whitespace-only changes.
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
from typing import Any, Dict, List, Sequence, Union
2+
3+
from pymc import Model
4+
from pymc.pytensorf import _replace_vars_in_graphs
5+
from pytensor.tensor import TensorVariable
6+
7+
from pymc_experimental.utils.model_fgraph import (
8+
ModelDeterministic,
9+
ModelFreeRV,
10+
extract_dims,
11+
fgraph_from_model,
12+
model_from_fgraph,
13+
model_named,
14+
model_observed_rv,
15+
toposort_replace,
16+
)
17+
18+
19+
def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable], Any]) -> Model:
20+
"""Convert free RVs or Deterministics to observed RVs.
21+
22+
Parameters
23+
----------
24+
model: PyMC Model
25+
vars_to_observations: Dict of variable or name to TensorLike
26+
Dictionary that maps model variables (or names) to observed values.
27+
Observed values must have a shape and data type that is compatible
28+
with the original model variable.
29+
30+
Returns
31+
-------
32+
new_model: PyMC model
33+
A distinct PyMC model with the relevant variables observed.
34+
All remaining variables are cloned and can be retrieved via `new_model["var_name"]`.
35+
36+
Examples
37+
--------
38+
39+
.. code-block:: python
40+
41+
import pymc as pm
42+
from pymc_experimental.model_transform.conditioning import observe
43+
44+
with pm.Model() as m:
45+
x = pm.Normal("x")
46+
y = pm.Normal("y", x)
47+
z = pm.Normal("z", y)
48+
49+
m_new = observe(m, {y: 0.5})
50+
51+
Deterministic variables can also be observed.
52+
This relies on PyMC ability to infer the logp of the underlying expression
53+
54+
.. code-block:: python
55+
56+
import pymc as pm
57+
from pymc_experimental.model_transform.conditioning import observe
58+
59+
with pm.Model() as m:
60+
x = pm.Normal("x")
61+
y = pm.Normal.dist(x, shape=(5,))
62+
y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1))
63+
64+
new_m = observe(m, {y_censored: [0.9, 0.5, 0.3, 1, 1]})
65+
66+
67+
"""
68+
vars_to_observations = {
69+
model[var] if isinstance(var, str) else var: obs
70+
for var, obs in vars_to_observations.items()
71+
}
72+
73+
valid_model_vars = set(model.free_RVs + model.deterministics)
74+
if any(var not in valid_model_vars for var in vars_to_observations):
75+
raise ValueError(f"At least one var is not a free variable or deterministic in the model")
76+
77+
fgraph, memo = fgraph_from_model(model)
78+
79+
replacements = {}
80+
for var, obs in vars_to_observations.items():
81+
model_var = memo[var]
82+
83+
# Just a sanity check
84+
assert isinstance(model_var.owner.op, (ModelFreeRV, ModelDeterministic))
85+
assert model_var in fgraph.variables
86+
87+
var = model_var.owner.inputs[0]
88+
var.name = model_var.name
89+
dims = extract_dims(var)
90+
model_obs_rv = model_observed_rv(var, var.type.filter_variable(obs), *dims)
91+
replacements[model_var] = model_obs_rv
92+
93+
toposort_replace(fgraph, tuple(replacements.items()))
94+
95+
return model_from_fgraph(fgraph)
96+
97+
98+
def replace_vars_in_graphs(graphs: Sequence[TensorVariable], replacements) -> List[TensorVariable]:
99+
def replacement_fn(var, inner_replacements):
100+
if var in replacements:
101+
inner_replacements[var] = replacements[var]
102+
103+
# Handle root inputs as those will never be passed to the replacement_fn
104+
for inp in var.owner.inputs:
105+
if inp.owner is None and inp in replacements:
106+
inner_replacements[inp] = replacements[inp]
107+
108+
return [var]
109+
110+
replaced_graphs, _ = _replace_vars_in_graphs(graphs=graphs, replacement_fn=replacement_fn)
111+
return replaced_graphs
112+
113+
114+
def do(model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], Any]) -> Model:
115+
"""Replace model variables by intervention variables.
116+
117+
Parameters
118+
----------
119+
model: PyMC Model
120+
vars_to_interventions: Dict of variable or name to TensorLike
121+
Dictionary that maps model variables (or names) to intervention expressions.
122+
Intervention expressions must have a shape and data type that is compatible
123+
with the original model variable.
124+
125+
Returns
126+
-------
127+
new_model: PyMC model
128+
A distinct PyMC model with the relevant variables replaced by the intervention expressions.
129+
All remaining variables are cloned and can be retrieved via `new_model["var_name"]`.
130+
131+
Examples
132+
--------
133+
134+
.. code-block:: python
135+
136+
import pymc as pm
137+
from pymc_experimental.model_transform.conditioning import do
138+
139+
with pm.Model() as m:
140+
x = pm.Normal("x", 0, 1)
141+
y = pm.Normal("y", x, 1)
142+
z = pm.Normal("z", y + x, 1)
143+
144+
# Dummy posterior, same as calling `pm.sample`
145+
idata_m = az.from_dict({rv.name: [pm.draw(rv, draws=500)] for rv in [x, y, z]})
146+
147+
# Replace `y` by a constant `100.0`
148+
m_do = do(m, {y: 100.0})
149+
with m_do:
150+
idata_do = pm.sample_posterior_predictive(idata_m, var_names="z")
151+
152+
"""
153+
do_mapping = {}
154+
for var, obs in vars_to_interventions.items():
155+
if isinstance(var, str):
156+
var = model[var]
157+
do_mapping[var] = var.type.filter_variable(obs)
158+
159+
if any(var not in (model.basic_RVs + model.deterministics) for var in do_mapping):
160+
raise ValueError(f"At least one var is not a variable or deterministic in the model")
161+
162+
fgraph, memo = fgraph_from_model(model)
163+
164+
# We need the interventions defined in terms of the IR fgraph representation,
165+
# In case they reference other variables in the model
166+
ir_interventions = replace_vars_in_graphs(list(do_mapping.values()), replacements=memo)
167+
168+
replacements = {}
169+
for var, intervention in zip(do_mapping, ir_interventions):
170+
model_var = memo[var]
171+
172+
# Just a sanity check
173+
assert model_var in fgraph.variables
174+
175+
intervention.name = model_var.name
176+
dims = extract_dims(model_var)
177+
new_var = model_named(intervention, *dims)
178+
179+
replacements[model_var] = new_var
180+
181+
# Replace variables by interventions
182+
toposort_replace(fgraph, tuple(replacements.items()))
183+
184+
return model_from_fgraph(fgraph)

pymc_experimental/tests/model_transform/__init__.py

Whitespace-only changes.
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import arviz as az
2+
import numpy as np
3+
import pymc as pm
4+
from pymc.variational.minibatch_rv import create_minibatch_rv
5+
from pytensor import config
6+
7+
from pymc_experimental.model_transform.conditioning import do, observe
8+
9+
10+
def test_observe():
11+
with pm.Model() as m_old:
12+
x = pm.Normal("x")
13+
y = pm.Normal("y", x)
14+
z = pm.Normal("z", y)
15+
16+
m_new = observe(m_old, {y: 0.5})
17+
18+
assert len(m_new.free_RVs) == 2
19+
assert len(m_new.observed_RVs) == 1
20+
assert m_new["x"] in m_new.free_RVs
21+
assert m_new["y"] in m_new.observed_RVs
22+
assert m_new["z"] in m_new.free_RVs
23+
24+
np.testing.assert_allclose(
25+
m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4}),
26+
m_new.compile_logp()({"x": 0.9, "z": 1.4}),
27+
)
28+
29+
# Test two substitutions
30+
m_new = observe(m_old, {y: 0.5, z: 1.4})
31+
32+
assert len(m_new.free_RVs) == 1
33+
assert len(m_new.observed_RVs) == 2
34+
assert m_new["x"] in m_new.free_RVs
35+
assert m_new["y"] in m_new.observed_RVs
36+
assert m_new["z"] in m_new.observed_RVs
37+
38+
np.testing.assert_allclose(
39+
m_old.compile_logp()({"x": 0.9, "y": 0.5, "z": 1.4}),
40+
m_new.compile_logp()({"x": 0.9}),
41+
)
42+
43+
44+
def test_observe_minibatch():
45+
data = np.zeros((100,), dtype=config.floatX)
46+
batch_size = 10
47+
with pm.Model() as m_old:
48+
x = pm.Normal("x")
49+
y = pm.Normal("y", x)
50+
# Minibatch RVs are usually created with `total_size` kwarg
51+
z_raw = pm.Normal.dist(y, shape=batch_size)
52+
mb_z = create_minibatch_rv(z_raw, total_size=data.shape)
53+
m_old.register_rv(mb_z, name="mb_z")
54+
55+
mb_data = pm.Minibatch(data, batch_size=batch_size)
56+
m_new = observe(m_old, {mb_z: mb_data})
57+
58+
assert len(m_new.free_RVs) == 2
59+
assert len(m_new.observed_RVs) == 1
60+
assert m_new["x"] in m_new.free_RVs
61+
assert m_new["y"] in m_new.free_RVs
62+
assert m_new["mb_z"] in m_new.observed_RVs
63+
64+
np.testing.assert_allclose(
65+
m_old.compile_logp()({"x": 0.9, "y": 0.5, "mb_z": np.zeros(10)}),
66+
m_new.compile_logp()({"x": 0.9, "y": 0.5}),
67+
)
68+
69+
70+
def test_observe_deterministic():
71+
y_censored_obs = np.array([0.9, 0.5, 0.3, 1, 1], dtype=config.floatX)
72+
73+
with pm.Model() as m_old:
74+
x = pm.Normal("x")
75+
y = pm.Normal.dist(x, shape=(5,))
76+
y_censored = pm.Deterministic("y_censored", pm.math.clip(y, -1, 1))
77+
78+
m_new = observe(m_old, {y_censored: y_censored_obs})
79+
80+
with pm.Model() as m_ref:
81+
x = pm.Normal("x")
82+
pm.Censored("y_censored", pm.Normal.dist(x), lower=-1, upper=1, observed=y_censored_obs)
83+
84+
np.testing.assert_allclose(
85+
m_new.compile_logp()({"x": 0.9}),
86+
m_ref.compile_logp()({"x": 0.9}),
87+
)
88+
89+
90+
def test_do():
91+
with pm.Model() as m_old:
92+
x = pm.Normal("x", 0, 1e-3)
93+
y = pm.Normal("y", x, 1e-3)
94+
z = pm.Normal("z", y + x, 1e-3)
95+
96+
assert -5 < pm.draw(z) < 5
97+
98+
m_new = do(m_old, {y: x + 100})
99+
100+
assert len(m_new.free_RVs) == 2
101+
assert m_new["x"] in m_new.free_RVs
102+
assert m_new["y"] in m_new.named_vars.values()
103+
assert m_new["z"] in m_new.free_RVs
104+
105+
assert 95 < pm.draw(m_new["z"]) < 105
106+
107+
# Test two substitutions
108+
with m_old:
109+
switch = pm.MutableData("switch", 1)
110+
m_new = do(m_old, {y: 100 * switch, x: 100 * switch})
111+
112+
assert len(m_new.free_RVs) == 1
113+
assert m_new["x"] in m_new.named_vars.values()
114+
assert m_new["y"] in m_new.named_vars.values()
115+
assert m_new["z"] in m_new.free_RVs
116+
117+
assert 195 < pm.draw(m_new["z"]) < 205
118+
with m_new:
119+
pm.set_data({"switch": 0})
120+
assert -5 < pm.draw(m_new["z"]) < 5
121+
122+
123+
def test_do_posterior_predictive():
124+
with pm.Model() as m:
125+
x = pm.Normal("x", 0, 1)
126+
y = pm.Normal("y", x, 1)
127+
z = pm.Normal("z", y + x, 1e-3)
128+
129+
# Dummy posterior
130+
idata_m = az.from_dict(
131+
{
132+
"x": np.full((2, 500), 25),
133+
"y": np.full((2, 500), np.nan),
134+
"z": np.full((2, 500), np.nan),
135+
}
136+
)
137+
138+
# Replace `y` by a constant `100.0`
139+
m_do = do(m, {y: 100.0})
140+
with m_do:
141+
idata_do = pm.sample_posterior_predictive(idata_m, var_names="z")
142+
143+
assert 120 < idata_do.posterior_predictive["z"].mean() < 130

pymc_experimental/utils/model_fgraph.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,13 @@ def clone_model(model: Model) -> Tuple[Model]:
326326
327327
"""
328328
return model_from_fgraph(fgraph_from_model(model)[0])
329+
330+
331+
def extract_dims(var) -> Tuple:
332+
dims = ()
333+
if isinstance(var, ModelVar):
334+
if isinstance(var, ModelValuedVar):
335+
dims = var.inputs[2:]
336+
else:
337+
dims = var.inputs[1:]
338+
return dims

0 commit comments

Comments
 (0)