@@ -25,13 +25,14 @@

 from pytensor import tensor as pt
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph import node_rewriter
+from pytensor.graph import FunctionGraph, node_rewriter
 from pytensor.graph.basic import Node, Variable
 from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.basic import in2out
 from pytensor.graph.utils import MetaType
 from pytensor.tensor.basic import as_tensor_variable
 from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
 from pytensor.tensor.random.utils import normalize_size_param
 from pytensor.tensor.var import TensorVariable
 from typing_extensions import TypeAlias
@@ -49,6 +50,7 @@
 )
 from pymc.exceptions import BlockModelAccessError
 from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
+from pymc.logprob.basic import logp
 from pymc.logprob.rewriting import logprob_rewrites_db
 from pymc.model import BlockModelAccess
 from pymc.printing import str_for_dist
@@ -1148,3 +1150,145 @@ def logcdf(value, c):
             -np.inf,
             0,
         )
+
+
+class PartialObservedRV(SymbolicRandomVariable):
+    """RandomVariable with partially observed subspace, as indicated by a boolean mask.
+
+    See `create_partial_observed_rv` for more details.
+    """
+
+
+def create_partial_observed_rv(
+    rv: TensorVariable,
+    mask: Union[np.ndarray, TensorVariable],
+) -> Tuple[
+    Tuple[TensorVariable, TensorVariable], Tuple[TensorVariable, TensorVariable], TensorVariable
+]:
+    """Separate observed and unobserved components of a RandomVariable.
+
+    This function may return two independent RandomVariables or, if that is not possible,
+    two variables from a common `PartialObservedRV` node.
+
+    Parameters
+    ----------
+    rv : TensorVariable
+    mask : tensor_like
+        Constant or variable boolean mask. True entries correspond to components of the
+        variable that are not observed.
+
+    Returns
+    -------
+    observed_rv and mask : Tuple of TensorVariable
+        The observed component of the RV and the respective indexing mask.
+    unobserved_rv and mask : Tuple of TensorVariable
+        The unobserved component of the RV and the respective indexing mask.
+    joined_rv : TensorVariable
+        The symbolic join of the observed and unobserved components.
+    """
+    if not mask.dtype == "bool":
+        raise ValueError(
+            f"mask must be an array or tensor of boolean dtype, got dtype: {mask.dtype}"
+        )
+
+    if mask.ndim > rv.ndim:
+        raise ValueError(f"mask can't have more dims than rv, got ndim: {mask.ndim}")
+
+    antimask = ~mask
+
+    can_rewrite = False
+    # Only pure RVs can be rewritten
+    if isinstance(rv.owner.op, RandomVariable):
+        ndim_supp = rv.owner.op.ndim_supp
+
+        # All univariate RVs can be rewritten
+        if ndim_supp == 0:
+            can_rewrite = True
+
+        # Multivariate RVs can be rewritten if masking does not split within support dimensions
+        else:
+            batch_dims = rv.type.ndim - ndim_supp
+            constant_mask = getattr(as_tensor_variable(mask), "data", None)
+
+            # Indexing does not overlap with core dimensions
+            if mask.ndim <= batch_dims:
+                can_rewrite = True
+
+            # Try to handle the special case where the mask is constant across support dimensions.
+            # TODO: This could be done by the rewrite itself
+            elif constant_mask is not None:
+                # We check if a constant_mask that only keeps the first entry of each support dim
+                # is equivalent to the original one after re-expanding.
+                trimmed_mask = constant_mask[(...,) + (0,) * ndim_supp]
+                expanded_mask = np.broadcast_to(
+                    np.expand_dims(trimmed_mask, axis=tuple(range(-ndim_supp, 0))),
+                    shape=constant_mask.shape,
+                )
+                if np.array_equal(constant_mask, expanded_mask):
+                    mask = trimmed_mask
+                    antimask = ~trimmed_mask
+                    can_rewrite = True
+
+    if can_rewrite:
+        # Rewrite doesn't work with boolean masks. Should be fixed after https://github.com/pymc-devs/pytensor/pull/329
+        mask, antimask = mask.nonzero(), antimask.nonzero()
+
+        masked_rv = rv[mask]
+        fgraph = FunctionGraph(outputs=[masked_rv], clone=False)
+        [unobserved_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
+
+        antimasked_rv = rv[antimask]
+        fgraph = FunctionGraph(outputs=[antimasked_rv], clone=False)
+        [observed_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
+
+        # Make a clone of the observed RV, with a distinct rng, so that observed and
+        # unobserved are never treated as equivalent (and mergeable) nodes by pytensor.
+        _, size, _, *inps = observed_rv.owner.inputs
+        observed_rv = observed_rv.owner.op(*inps, size=size)
+
+    # For all other cases use the more general PartialObservedRV
+    else:
+        # The symbolic graph simply splits the observed and unobserved components,
+        # so they can be given separate values.
+        dist_, mask_ = rv.type(), as_tensor_variable(mask).type()
+        observed_rv_, unobserved_rv_ = dist_[~mask_], dist_[mask_]
+
+        observed_rv, unobserved_rv = PartialObservedRV(
+            inputs=[dist_, mask_],
+            outputs=[observed_rv_, unobserved_rv_],
+            ndim_supp=rv.owner.op.ndim_supp,
+        )(rv, mask)
+
+    joined_rv = pt.empty(rv.shape, dtype=rv.type.dtype)
+    joined_rv = pt.set_subtensor(joined_rv[mask], unobserved_rv)
+    joined_rv = pt.set_subtensor(joined_rv[antimask], observed_rv)
+
+    return (observed_rv, antimask), (unobserved_rv, mask), joined_rv
+
+
+@_logprob.register(PartialObservedRV)
+def partial_observed_rv_logprob(op, values, dist, mask, **kwargs):
+    # For the logp, simply join the values
+    [obs_value, unobs_value] = values
+    antimask = ~mask
+    joined_value = pt.empty_like(dist)
+    joined_value = pt.set_subtensor(joined_value[mask], unobs_value)
+    joined_value = pt.set_subtensor(joined_value[antimask], obs_value)
+    joined_logp = logp(dist, joined_value)
+
+    # If we have a univariate RV we can split apart the logp terms
+    if op.ndim_supp == 0:
+        return joined_logp[antimask], joined_logp[mask]
+    # Otherwise, we can't (always/easily) split apart the logp terms.
+    # We return the full logp for the observed value and an empty (zero-length) array for the unobserved value.
+    else:
+        return joined_logp.ravel(), pt.zeros((0,), dtype=joined_logp.type.dtype)
+
+
+@_moment.register(PartialObservedRV)
+def partial_observed_rv_moment(op, partial_obs_rv, rv, mask):
+    # Unobserved output
+    if partial_obs_rv.owner.outputs.index(partial_obs_rv) == 1:
+        return moment(rv)[mask]
+    # Observed output
+    else:
+        return moment(rv)[~mask]
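
A minimal usage sketch (not part of the patch) of what `create_partial_observed_rv` does for a univariate RV, assuming it is importable from `pymc.distributions.distribution` (the module touched here); variable names are illustrative:

```python
import numpy as np
import pymc as pm

from pymc.distributions.distribution import create_partial_observed_rv

# A vector of 4 iid Normals; True entries of the mask are treated as unobserved.
rv = pm.Normal.dist(mu=0.0, sigma=1.0, shape=(4,))
mask = np.array([False, True, False, True])

(obs_rv, obs_idx), (unobs_rv, unobs_idx), joined_rv = create_partial_observed_rv(rv, mask)

# Normal is univariate (ndim_supp == 0), so the rewrite path applies and the two
# components come back as independent RandomVariables with 2 entries each.
print(obs_rv.eval().shape, unobs_rv.eval().shape)  # (2,) (2,)

# joined_rv symbolically stitches both components back into the original shape.
print(joined_rv.eval().shape)  # (4,)
```

When the mask cuts across support dimensions (e.g. masking individual entries of a multivariate normal draw), the fallback branch is taken instead and both components come from a single `PartialObservedRV` node.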
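The `_logprob` and `_moment` registrations above are what let these variables take part in model compilation. For context, this machinery backs PyMC's automatic imputation of missing values; a rough sketch of the user-facing behaviour, assuming NaN observations are converted to a masked array as in current PyMC (exact variable naming may differ between versions):

```python
import numpy as np
import pymc as pm

# Observations with two missing entries. The NaNs are masked, and `y` is split into
# an observed component (entering the likelihood) and an unobserved component
# (sampled as a latent variable).
data = np.array([0.5, np.nan, -0.3, np.nan, 1.2])

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, sigma=1.0)
    y = pm.Normal("y", mu=mu, sigma=1.0, observed=data)

print(model.free_RVs)      # mu plus the unobserved part of y
print(model.observed_RVs)  # the observed part of y
```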