|
| 1 | +# Copyright 2024 The PyMC Developers |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +from pytensor.graph.basic import Variable |
| 15 | +from pytensor.graph.rewriting.basic import node_rewriter |
| 16 | +from pytensor.tensor import eq |
| 17 | +from pytensor.tensor.subtensor import ( |
| 18 | + AdvancedIncSubtensor, |
| 19 | + AdvancedIncSubtensor1, |
| 20 | + IncSubtensor, |
| 21 | + indices_from_subtensor, |
| 22 | +) |
| 23 | +from pytensor.tensor.type import TensorType |
| 24 | +from pytensor.tensor.type_other import NoneTypeT |
| 25 | + |
| 26 | +from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper |
| 27 | +from pymc.logprob.checks import MeasurableCheckAndRaise |
| 28 | +from pymc.logprob.rewriting import measurable_ir_rewrites_db |
| 29 | +from pymc.logprob.utils import ( |
| 30 | + check_potential_measurability, |
| 31 | + dirac_delta, |
| 32 | + filter_measurable_variables, |
| 33 | +) |
| 34 | + |
| 35 | + |
class MeasurableSetSubtensor(IncSubtensor, MeasurableOp):
    """Measurable SetSubtensor Op."""

    def __str__(self):
        # Prefix the base Op's string representation to mark measurability.
        base_repr = super().__str__()
        return "Measurable" + base_repr
| 41 | + |
| 42 | + |
class MeasurableAdvancedSetSubtensor(AdvancedIncSubtensor, MeasurableOp):
    """Measurable AdvancedSetSubtensor Op."""

    def __str__(self):
        # Same display convention as the other measurable Ops in this module.
        return "Measurable" + super().__str__()
| 48 | + |
| 49 | + |
# Runtime guard used by the rewrite below: the logp derived for a SetSubtensor
# is only valid when the set value is not broadcast into the indexed block, so
# when that cannot be ruled out statically we insert this check into the graph.
set_subtensor_does_not_broadcast = MeasurableCheckAndRaise(
    exc_type=NotImplementedError,
    msg="Measurable SetSubtensor not supported when set value is broadcasted.",
)
| 54 | + |
| 55 | + |
@node_rewriter(tracks=[IncSubtensor, AdvancedIncSubtensor1, AdvancedIncSubtensor])
def find_measurable_set_subtensor(fgraph, node) -> list | None:
    """Find `SetSubtensor` for which a `logprob` can be computed.

    Replaces a set-subtensor node whose set value ``y`` is measurable with the
    corresponding Measurable(Advanced)SetSubtensor Op, so that ``_logprob`` can
    later be dispatched on it. Returns ``None`` when the node does not qualify
    (yet), which tells the rewriter to leave it untouched.
    """
    if isinstance(node.op, MeasurableOp):
        # Already rewritten to a measurable Op; nothing to do.
        return None

    if not node.op.set_instead_of_inc:
        # Only set (x[idx].set(y)), not increment (x[idx] += y), is supported.
        return None

    x, y, *idx_elements = node.inputs

    measurable_inputs = filter_measurable_variables([x, y])

    if y not in measurable_inputs:
        # The set value must be measurable for this rewrite to make sense.
        return None

    if x not in measurable_inputs:
        # x is potentially measurable, wait for its logprob IR to be inferred
        if check_potential_measurability([x]):
            return None
        # x has no link to measurable variables, so its value should be constant
        else:
            x = dirac_delta(x, rtol=0, atol=0)

    if check_potential_measurability(idx_elements):
        # Measurable variables in the indices are not supported.
        return None

    measurable_class: type[MeasurableSetSubtensor | MeasurableAdvancedSetSubtensor]
    if isinstance(node.op, IncSubtensor):
        measurable_class = MeasurableSetSubtensor
        # Basic indexing: rebuild the actual index tuple from the op's idx_list.
        idx = indices_from_subtensor(idx_elements, node.op.idx_list)
    else:
        measurable_class = MeasurableAdvancedSetSubtensor
        idx = tuple(idx_elements)

    # Check that y is not certainly broadcasted.
    indexed_block = x[idx]
    missing_y_dims = indexed_block.type.ndim - y.type.ndim
    # Left-pad y's broadcast pattern so it aligns with the indexed block's dims.
    y_bcast = [True] * missing_y_dims + list(y.type.broadcastable)
    if any(
        y_dim_bcast and indexed_block_dim_len not in (None, 1)
        for y_dim_bcast, indexed_block_dim_len in zip(
            y_bcast, indexed_block.type.shape, strict=True
        )
    ):
        # y is statically known to broadcast into the block: logp is invalid.
        return None

    measurable_set_subtensor = measurable_class(**node.op._props_dict())(x, y, *idx_elements)

    # Often with indexing we don't know the static shape of the indexed block.
    # And, what's more, the indexing operations actually support runtime broadcasting.
    # As the logp is not valid under broadcasting, we have to add a runtime check.
    # This will hopefully be removed during shape inference when not violated.
    potential_broadcasted_dims = [
        i
        for i, (y_bcast_dim, indexed_block_dim_len) in enumerate(
            zip(y_bcast, indexed_block.type.shape)
        )
        if y_bcast_dim and indexed_block_dim_len is None
    ]
    if potential_broadcasted_dims:
        indexed_block_shape = tuple(indexed_block.shape)
        measurable_set_subtensor = set_subtensor_does_not_broadcast(
            measurable_set_subtensor,
            *(eq(indexed_block_shape[i], 1) for i in potential_broadcasted_dims),
        )

    return [measurable_set_subtensor]
| 124 | + |
| 125 | + |
# Register the rewrite in the measurable IR database so it runs during
# logprob inference, tagged "basic" and "set_subtensor".
measurable_ir_rewrites_db.register(
    find_measurable_set_subtensor.__name__,
    find_measurable_set_subtensor,
    "basic",
    "set_subtensor",
)
| 132 | + |
| 133 | + |
def indexed_dims(idx) -> list[int | None]:
    """Return the indices of the dimensions of the indexed tensor that are being indexed.

    ``None`` entries correspond to ``newaxis`` indexes, which introduce a new
    dimension and therefore map to no pre-existing dimension of the tensor.
    """
    dims: list[int | None] = []
    next_dim = 0
    for entry in idx:
        entry_is_var = isinstance(entry, Variable)
        if entry_is_var and isinstance(entry.type, NoneTypeT):
            # newaxis: does not consume any existing dimension
            dims.append(None)
            continue
        if entry_is_var and isinstance(entry.type, TensorType) and entry.type.dtype == "bool":
            # A boolean mask consumes one dimension per dimension of the mask
            consumed = entry.type.ndim
        else:
            # Scalar / slice / integer-array indexes consume a single dimension
            consumed = 1
        for _ in range(consumed):
            dims.append(next_dim)
            next_dim += 1

    return dims
| 157 | + |
| 158 | + |
@_logprob.register(MeasurableSetSubtensor)
@_logprob.register(MeasurableAdvancedSetSubtensor)
def logprob_setsubtensor(op, values, x, y, *idx_elements, **kwargs):
    """Compute the log-likelihood graph for a `SetSubtensor`.

    For a generative graph like:
        o = zeros(2)
        x = o[0].set(X)
        y = x[1].set(Y)

    The log-likelihood graph is:
        logp(y, value) = (
            logp(x, value)
            [1].set(logp(y, value[1]))
        )

    Unrolling the logp(x, value) gives:
        logp(y, value) = (
            DiracDelta(zeros(2), value)  # Irrelevant if all entries are set
            [0].set(logp(x, value[0]))
            [1].set(logp(y, value[1]))
        )

    Raises
    ------
    NotImplementedError
        If a multivariate component is indexed along its core (support)
        dimensions, where overlapping writes cannot be ruled out.
    """
    [value] = values
    if isinstance(op, MeasurableSetSubtensor):
        # For basic indexing we have to recreate the index from the input list
        idx = indices_from_subtensor(idx_elements, op.idx_list)
    else:
        # For advanced indexing we can use the idx_elements directly
        idx = tuple(idx_elements)

    # logp of the base variable at the full value, and logp of the set
    # variable at the indexed portion of the value.
    x_logp = _logprob_helper(x, value)
    y_logp = _logprob_helper(y, value[idx])

    # Support ndim of each component, inferred from how many dimensions its
    # logp dropped relative to the variable it scores.
    y_ndim_supp = x[idx].type.ndim - y_logp.type.ndim
    x_ndim_supp = x.type.ndim - x_logp.type.ndim
    ndim_supp = max(y_ndim_supp, x_ndim_supp)
    if ndim_supp > 0:
        # Multivariate logp only valid if we are not doing indexing along the reduced dimensions
        # Otherwise we don't know if successive writings are overlapping or not
        core_dims = set(range(x.type.ndim)[-ndim_supp:])
        if set(indexed_dims(idx)) & core_dims:
            # When we have IR meta-info about support_ndim, we can fail at the rewriting stage
            raise NotImplementedError(
                "Indexing along core dimensions of multivariate SetSubtensor not supported"
            )

        # Align the two logp graphs' ndims before combining them below.
        ndim_supp_diff = y_ndim_supp - x_ndim_supp
        if ndim_supp_diff > 0:
            # In this case y_logp will have fewer dimensions than x_logp after indexing, so we need to reduce x before indexing.
            x_logp = x_logp.sum(axis=tuple(range(-ndim_supp_diff, 0)))
        elif ndim_supp_diff < 0:
            # In this case x_logp will have fewer dimensions than y_logp after indexing, so we need to reduce y before indexing.
            y_logp = y_logp.sum(axis=tuple(range(ndim_supp_diff, 0)))

    # Write the logp of the set block into the matching entries of the base
    # logp, mirroring the forward SetSubtensor operation.
    out_logp = x_logp[idx].set(y_logp)
    return out_logp
0 commit comments