Commit 1da28b1

Extend support for automatic imputation
1 parent 12de79c · commit 1da28b1

5 files changed: +442 −96 lines changed
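
For context, the feature this commit extends is triggered by passing a masked array as `observed`. A minimal usage sketch (model and variable names are illustrative, not taken from this commit):

import numpy as np
import pymc as pm

# -1 marks missing entries; PyMC imputes masked values automatically
data = np.ma.masked_values([1.0, 2.0, -1.0, 4.0, -1.0], value=-1.0)

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    # emits an ImputationWarning and registers x_observed and x_missing,
    # plus a Deterministic "x" joining both components
    x = pm.Normal("x", mu=mu, sigma=1, observed=data)
    idata = pm.sample()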

pymc/distributions/distribution.py

Lines changed: 145 additions & 1 deletion
@@ -25,13 +25,14 @@
 
 from pytensor import tensor as pt
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph import node_rewriter
+from pytensor.graph import FunctionGraph, node_rewriter
 from pytensor.graph.basic import Node, Variable
 from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.basic import in2out
 from pytensor.graph.utils import MetaType
 from pytensor.tensor.basic import as_tensor_variable
 from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
 from pytensor.tensor.random.utils import normalize_size_param
 from pytensor.tensor.var import TensorVariable
 from typing_extensions import TypeAlias
@@ -49,6 +50,7 @@
 )
 from pymc.exceptions import BlockModelAccessError
 from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
+from pymc.logprob.basic import logp
 from pymc.logprob.rewriting import logprob_rewrites_db
 from pymc.model import BlockModelAccess
 from pymc.printing import str_for_dist
@@ -1148,3 +1150,145 @@ def logcdf(value, c):
         -np.inf,
         0,
     )
+
+
+class PartialObservedRV(SymbolicRandomVariable):
+    """RandomVariable with partially observed subspace, as indicated by a boolean mask.
+
+    See `create_partial_observed_rv` for more details.
+    """
+
+
+def create_partial_observed_rv(
+    rv: TensorVariable,
+    mask: Union[np.ndarray, TensorVariable],
+) -> Tuple[
+    Tuple[TensorVariable, TensorVariable], Tuple[TensorVariable, TensorVariable], TensorVariable
+]:
+    """Separate observed and unobserved components of a RandomVariable.
+
+    This function may return two independent RandomVariables or, if not possible,
+    two variables from a common `PartialObservedRV` node.
+
+    Parameters
+    ----------
+    rv : TensorVariable
+    mask : tensor_like
+        Constant or variable boolean mask. True entries correspond to components of the variable that are not observed.
+
+    Returns
+    -------
+    observed_rv and mask : Tuple of TensorVariable
+        The observed component of the RV and the respective indexing mask
+    unobserved_rv and mask : Tuple of TensorVariable
+        The unobserved component of the RV and the respective indexing mask
+    joined_rv : TensorVariable
+        The symbolic join of the observed and unobserved components.
+    """
+    if not mask.dtype == "bool":
+        raise ValueError(
+            f"mask must be an array or tensor of boolean dtype, got dtype: {mask.dtype}"
+        )
+
+    if mask.ndim > rv.ndim:
+        raise ValueError(f"mask can't have more dims than rv, got ndim: {mask.ndim}")
+
+    antimask = ~mask
+
+    can_rewrite = False
+    # Only pure RVs can be rewritten
+    if isinstance(rv.owner.op, RandomVariable):
+        ndim_supp = rv.owner.op.ndim_supp
+
+        # All univariate RVs can be rewritten
+        if ndim_supp == 0:
+            can_rewrite = True
+
+        # Multivariate RVs can be rewritten if masking does not split within support dimensions
+        else:
+            batch_dims = rv.type.ndim - ndim_supp
+            constant_mask = getattr(as_tensor_variable(mask), "data", None)
+
+            # Indexing does not overlap with core dimensions
+            if mask.ndim <= batch_dims:
+                can_rewrite = True
+
+            # Try to handle special case where mask is constant across support dimensions
+            # TODO: This could be done by the rewrite itself
+            elif constant_mask is not None:
+                # We check if a constant_mask that only keeps the first entry of each support dim
+                # is equivalent to the original one after re-expanding.
+                trimmed_mask = constant_mask[(...,) + (0,) * ndim_supp]
+                expanded_mask = np.broadcast_to(
+                    np.expand_dims(trimmed_mask, axis=tuple(range(-ndim_supp, 0))),
+                    shape=constant_mask.shape,
+                )
+                if np.array_equal(constant_mask, expanded_mask):
+                    mask = trimmed_mask
+                    antimask = ~trimmed_mask
+                    can_rewrite = True
+
+    if can_rewrite:
+        # Rewrite doesn't work with boolean masks. Should be fixed after https://github.com/pymc-devs/pytensor/pull/329
+        mask, antimask = mask.nonzero(), antimask.nonzero()
+
+        masked_rv = rv[mask]
+        fgraph = FunctionGraph(outputs=[masked_rv], clone=False)
+        [unobserved_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
+
+        antimasked_rv = rv[antimask]
+        fgraph = FunctionGraph(outputs=[antimasked_rv], clone=False)
+        [observed_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
+
+        # Make a clone of the observed RV, with a distinct rng so that observed and
+        # unobserved are never treated as equivalent (and mergeable) nodes by pytensor.
+        _, size, _, *inps = observed_rv.owner.inputs
+        observed_rv = observed_rv.owner.op(*inps, size=size)
+
+    # For all other cases use the more general PartialObservedRV
+    else:
+        # The symbolic graph simply splits the observed and unobserved components,
+        # so they can be given separate values.
+        dist_, mask_ = rv.type(), as_tensor_variable(mask).type()
+        observed_rv_, unobserved_rv_ = dist_[~mask_], dist_[mask_]
+
+        observed_rv, unobserved_rv = PartialObservedRV(
+            inputs=[dist_, mask_],
+            outputs=[observed_rv_, unobserved_rv_],
+            ndim_supp=rv.owner.op.ndim_supp,
+        )(rv, mask)
+
+    joined_rv = pt.empty(rv.shape, dtype=rv.type.dtype)
+    joined_rv = pt.set_subtensor(joined_rv[mask], unobserved_rv)
+    joined_rv = pt.set_subtensor(joined_rv[antimask], observed_rv)
+
+    return (observed_rv, antimask), (unobserved_rv, mask), joined_rv
+
+
+@_logprob.register(PartialObservedRV)
+def partial_observed_rv_logprob(op, values, dist, mask, **kwargs):
+    # For the logp, simply join the values
+    [obs_value, unobs_value] = values
+    antimask = ~mask
+    joined_value = pt.empty_like(dist)
+    joined_value = pt.set_subtensor(joined_value[mask], unobs_value)
+    joined_value = pt.set_subtensor(joined_value[antimask], obs_value)
+    joined_logp = logp(dist, joined_value)
+
+    # If we have a univariate RV we can split apart the logp terms
+    if op.ndim_supp == 0:
+        return joined_logp[antimask], joined_logp[mask]
+    # Otherwise, we can't (always/easily) split apart logp terms.
+    # We return the full logp for the observed value, and a zero-length array for the unobserved value
+    else:
+        return joined_logp.ravel(), pt.zeros((0,), dtype=joined_logp.type.dtype)
+
+
+@_moment.register(PartialObservedRV)
+def partial_observed_rv_moment(op, partial_obs_rv, rv, mask):
+    # Unobserved output
+    if partial_obs_rv.owner.outputs.index(partial_obs_rv) == 1:
+        return moment(rv)[mask]
+    # Observed output
+    else:
+        return moment(rv)[~mask]
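
To make the two code paths above concrete, here is a hedged sketch of calling `create_partial_observed_rv` directly (shapes and names are illustrative, not from the commit):

import numpy as np
import pymc as pm
from pymc.distributions.distribution import create_partial_observed_rv

# Univariate case: the mask can always be lifted through the RandomVariable,
# yielding two independent Normal RVs (here of sizes 3 observed and 2 unobserved).
rv = pm.Normal.dist(mu=0, sigma=1, size=5)
mask = np.array([False, True, False, True, False])  # True = unobserved
(obs_rv, antimask), (unobs_rv, mask_idx), joined = create_partial_observed_rv(rv, mask)

# Multivariate case where the mask splits a support dimension: the subtensor
# rewrite is not possible, so both components come from one PartialObservedRV node.
mv = pm.MvNormal.dist(mu=np.zeros(2), cov=np.eye(2), size=3)
mv_mask = np.array([[False, True], [False, False], [True, True]])
(obs2, _), (unobs2, _), joined2 = create_partial_observed_rv(mv, mv_mask)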

pymc/model.py

Lines changed: 16 additions & 58 deletions
@@ -44,11 +44,9 @@
 from pytensor.compile import DeepCopyOp, get_mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant, Variable, graph_inputs
-from pytensor.graph.fg import FunctionGraph
 from pytensor.scalar import Cast
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.op import RandomVariable
-from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
 from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.sharedvar import ScalarSharedVariable
 from pytensor.tensor.var import TensorConstant, TensorVariable
@@ -1364,67 +1362,27 @@ def make_obs_var(
             if total_size is not None:
                 raise ValueError("total_size is not compatible with imputed variables")
 
-            if not isinstance(rv_var.owner.op, RandomVariable):
-                raise NotImplementedError(
-                    "Automatic inputation is only supported for univariate RandomVariables."
-                    f" {rv_var} of type {type(rv_var.owner.op)} is not supported."
-                )
-
-            if rv_var.owner.op.ndim_supp > 0:
-                raise NotImplementedError(
-                    f"Automatic inputation is only supported for univariate "
-                    f"RandomVariables, but {rv_var} is multivariate"
-                )
+            from pymc.distributions.distribution import create_partial_observed_rv
 
-            # We can get a random variable comprised of only the unobserved
-            # entries by lifting the indices through the `RandomVariable` `Op`.
+            (
+                (observed_rv, observed_mask),
+                (unobserved_rv, _),
+                joined_rv,
+            ) = create_partial_observed_rv(rv_var, mask)
+            observed_data = pt.as_tensor(data.data[observed_mask])
 
-            masked_rv_var = rv_var[mask.nonzero()]
-
-            fgraph = FunctionGraph(
-                [i for i in graph_inputs((masked_rv_var,)) if not isinstance(i, Constant)],
-                [masked_rv_var],
-                clone=False,
-            )
+            # Register ObservedRV corresponding to observed component
+            observed_rv.name = f"{name}_observed"
+            self.create_value_var(observed_rv, transform=None, value_var=observed_data)
+            self.add_named_variable(observed_rv)
+            self.observed_RVs.append(observed_rv)
 
-            (missing_rv_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
+            # Register FreeRV corresponding to unobserved components
+            self.register_rv(unobserved_rv, f"{name}_missing", transform=transform)
 
-            self.register_rv(missing_rv_var, f"{name}_missing", transform=transform)
-
-            # Now, we lift the non-missing observed values and produce a new
-            # `rv_var` that contains only those.
-            #
-            # The end result is two disjoint distributions: one for the missing
-            # values, and another for the non-missing values.
-
-            antimask_idx = (~mask).nonzero()
-            nonmissing_data = pt.as_tensor_variable(data[antimask_idx].data)
-            unmasked_rv_var = rv_var[antimask_idx]
-            unmasked_rv_var = unmasked_rv_var.owner.clone().default_output()
-
-            fgraph = FunctionGraph(
-                [i for i in graph_inputs((unmasked_rv_var,)) if not isinstance(i, Constant)],
-                [unmasked_rv_var],
-                clone=False,
-            )
-            (observed_rv_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
-            # Make a clone of the RV, but let it create a new rng so that observed and
-            # missing are not treated as equivalent nodes by pytensor. This would happen
-            # if the size of the masked and unmasked array happened to coincide
-            _, size, _, *inps = observed_rv_var.owner.inputs
-            observed_rv_var = observed_rv_var.owner.op(*inps, size=size, name=f"{name}_observed")
-            observed_rv_var.tag.observations = nonmissing_data
-
-            self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
-            self.add_named_variable(observed_rv_var)
-            self.observed_RVs.append(observed_rv_var)
-
-            # Create deterministic that combines observed and missing
+            # Register Deterministic that combines observed and missing
             # Note: This can widely increase memory consumption during sampling for large datasets
-            rv_var = pt.empty(data.shape, dtype=observed_rv_var.type.dtype)
-            rv_var = pt.set_subtensor(rv_var[mask.nonzero()], missing_rv_var)
-            rv_var = pt.set_subtensor(rv_var[antimask_idx], observed_rv_var)
-            rv_var = Deterministic(name, rv_var, self, dims)
+            rv_var = Deterministic(name, joined_rv, self, dims)
 
         else:
             if sps.issparse(data):
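
The net effect of the `make_obs_var` change: masked observations now route through `create_partial_observed_rv`, so the old univariate-only restriction (and its NotImplementedError) is gone. A hypothetical sketch of what the model registers, following the `{name}_observed`/`{name}_missing` naming above:

import numpy as np
import pymc as pm

# A multivariate RV with partially observed rows; this previously raised
# NotImplementedError during automatic imputation
data = np.ma.masked_values([[1.0, 2.0], [-1.0, 4.0]], value=-1.0)
with pm.Model() as m:
    y = pm.MvNormal("y", mu=np.zeros(2), cov=np.eye(2), observed=data)

print(m.observed_RVs)  # expected: [y_observed]
print(m.free_RVs)      # expected: [y_missing]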

tests/backends/test_arviz.py

Lines changed: 11 additions & 6 deletions
@@ -352,7 +352,6 @@ def test_missing_data_model(self):
         # See https://github.com/pymc-devs/pymc/issues/5255
         assert inference_data.log_likelihood["y_observed"].shape == (2, 100, 3)
 
-    @pytest.mark.xfail(reason="Multivariate partial observed RVs not implemented for V4")
     def test_mv_missing_data_model(self):
         data = ma.masked_values([[1, 2], [2, 2], [-1, 4], [2, -1], [-1, -1]], value=-1)
 
@@ -361,19 +360,25 @@ def test_mv_missing_data_model(self):
             mu = pm.Normal("mu", 0, 1, size=2)
             sd_dist = pm.HalfNormal.dist(1.0, size=2)
             # pylint: disable=unpacking-non-sequence
-            chol, *_ = pm.LKJCholeskyCov("chol_cov", n=2, eta=1, sd_dist=sd_dist, compute_corr=True)
+            chol, *_ = pm.LKJCholeskyCov("chol_cov", n=2, eta=1, sd_dist=sd_dist)
             # pylint: enable=unpacking-non-sequence
             with pytest.warns(ImputationWarning):
                 y = pm.MvNormal("y", mu=mu, chol=chol, observed=data)
-            inference_data = pm.sample(100, chains=2, return_inferencedata=True)
+            inference_data = pm.sample(
+                tune=10,
+                draws=10,
+                chains=2,
+                step=pm.Metropolis(),
+                idata_kwargs=dict(log_likelihood=True),
+            )
 
         # make sure that data is really missing
-        assert isinstance(y.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1))
+        assert isinstance(y.owner.inputs[0].owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1))
 
         test_dict = {
             "posterior": ["mu", "chol_cov"],
-            "observed_data": ["y"],
-            "log_likelihood": ["y"],
+            "observed_data": ["y_observed"],
+            "log_likelihood": ["y_observed"],
         }
         fails = check_multiple_attrs(test_dict, inference_data)
         assert not fails
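
A hedged end-to-end sketch of the expectation this test now encodes: the observed component is registered under its derived name, so the InferenceData groups key on `y_observed`, while imputed entries appear as the free RV `y_missing` (data and sampler settings are illustrative):

import numpy as np
import numpy.ma as ma
import pymc as pm

data = ma.masked_values([[1.0, 2.0], [2.0, 2.0], [-1.0, 4.0]], value=-1.0)
with pm.Model():
    mu = pm.Normal("mu", 0, 1, size=2)
    y = pm.MvNormal("y", mu=mu, cov=np.eye(2), observed=data)
    idata = pm.sample(tune=10, draws=10, chains=2, idata_kwargs=dict(log_likelihood=True))

assert "y_observed" in idata.observed_data
assert "y_observed" in idata.log_likelihood
assert "y_missing" in idata.posterior  # imputed entries are sampled as a free RV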
