|
37 | 37 | import warnings
|
38 | 38 |
|
39 | 39 | from copy import copy
|
40 |
| -from typing import Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple |
| 40 | +from typing import ( |
| 41 | + Callable, |
| 42 | + Dict, |
| 43 | + Generator, |
| 44 | + Iterable, |
| 45 | + List, |
| 46 | + Optional, |
| 47 | + Sequence, |
| 48 | + Set, |
| 49 | + Tuple, |
| 50 | +) |
41 | 51 |
|
42 | 52 | import numpy as np
|
43 | 53 |
|
@@ -265,7 +275,7 @@ def diracdelta_logprob(op, values, *inputs, **kwargs):
|
265 | 275 | def ignore_logprob(rv: TensorVariable) -> TensorVariable:
|
266 | 276 | """Return a duplicated variable that is ignored when creating logprob graphs
|
267 | 277 |
|
268 |
| - This is used in SymbolicDistributions that use other RVs as inputs but account |
| 278 | + This is used in by MeasurableRVs that use other RVs as inputs but account |
269 | 279 | for their logp terms explicitly.
|
270 | 280 |
|
271 | 281 | If the variable is already ignored, it is returned directly.
|
@@ -298,3 +308,32 @@ def reconsider_logprob(rv: TensorVariable) -> TensorVariable:
|
298 | 308 | new_node.op = copy(new_node.op)
|
299 | 309 | new_node.op.__class__ = original_op_type
|
300 | 310 | return new_node.outputs[node.outputs.index(rv)]
|
| 311 | + |
| 312 | + |
| 313 | +def ignore_logprob_multiple_vars( |
| 314 | + vars: Sequence[TensorVariable], rvs_to_values: Dict[TensorVariable, TensorVariable] |
| 315 | +) -> List[TensorVariable]: |
| 316 | + """Return duplicated variables that are ignored when creating logprob graphs. |
| 317 | +
|
| 318 | + This function keeps any interdependencies between variables intact, after |
| 319 | + making each "unmeasurable", whereas a sequential call to `ignore_logprob` |
| 320 | + would not do this correctly. |
| 321 | + """ |
| 322 | + from pymc.pytensorf import _replace_rvs_in_graphs |
| 323 | + |
| 324 | + measurable_vars_to_unmeasurable_vars = { |
| 325 | + measurable_var: ignore_logprob(measurable_var) for measurable_var in vars |
| 326 | + } |
| 327 | + |
| 328 | + def replacement_fn(var, replacements): |
| 329 | + if var in measurable_vars_to_unmeasurable_vars: |
| 330 | + replacements[var] = measurable_vars_to_unmeasurable_vars[var] |
| 331 | + # We don't want to clone valued nodes. Assigning a var to itself in the |
| 332 | + # replacements prevents this |
| 333 | + elif var in rvs_to_values: |
| 334 | + replacements[var] = var |
| 335 | + |
| 336 | + return [] |
| 337 | + |
| 338 | + unmeasurable_vars, _ = _replace_rvs_in_graphs(graphs=vars, replacement_fn=replacement_fn) |
| 339 | + return unmeasurable_vars |
0 commit comments