Skip to content

Commit 9836d00

Browse files
committed
Refactor utility to ignore the logprob of multiple variables while keeping their interdependencies intact
1 parent 1cc9863 commit 9836d00

File tree

2 files changed

+43
-22
lines changed

2 files changed

+43
-22
lines changed

pymc/logprob/tensor.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252

5353
from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob
5454
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
55-
from pymc.logprob.utils import ignore_logprob
55+
from pymc.logprob.utils import ignore_logprob, ignore_logprob_multiple_vars
5656

5757

5858
@node_rewriter([BroadcastTo])
@@ -228,25 +228,7 @@ def find_measurable_stacks(
228228
):
229229
return None # pragma: no cover
230230

231-
# Make base_vars unmeasurable
232-
base_to_unmeasurable_vars = {base_var: ignore_logprob(base_var) for base_var in base_vars}
233-
234-
def replacement_fn(var, replacements):
235-
if var in base_to_unmeasurable_vars:
236-
replacements[var] = base_to_unmeasurable_vars[var]
237-
# We don't want to clone valued nodes. Assigning a var to itself in the
238-
# replacements prevents this
239-
elif var in rvs_to_values:
240-
replacements[var] = var
241-
242-
return []
243-
244-
# TODO: Fix this import circularity!
245-
from pymc.pytensorf import _replace_rvs_in_graphs
246-
247-
unmeasurable_base_vars, _ = _replace_rvs_in_graphs(
248-
graphs=base_vars, replacement_fn=replacement_fn
249-
)
231+
unmeasurable_base_vars = ignore_logprob_multiple_vars(base_vars, rvs_to_values)
250232

251233
if is_join:
252234
measurable_stack = MeasurableJoin()(axis, *unmeasurable_base_vars)

pymc/logprob/utils.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,17 @@
3737
import warnings
3838

3939
from copy import copy
40-
from typing import Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple
40+
from typing import (
41+
Callable,
42+
Dict,
43+
Generator,
44+
Iterable,
45+
List,
46+
Optional,
47+
Sequence,
48+
Set,
49+
Tuple,
50+
)
4151

4252
import numpy as np
4353

@@ -265,7 +275,7 @@ def diracdelta_logprob(op, values, *inputs, **kwargs):
265275
def ignore_logprob(rv: TensorVariable) -> TensorVariable:
266276
"""Return a duplicated variable that is ignored when creating logprob graphs
267277
268-
This is used in SymbolicDistributions that use other RVs as inputs but account
278+
This is used in by MeasurableRVs that use other RVs as inputs but account
269279
for their logp terms explicitly.
270280
271281
If the variable is already ignored, it is returned directly.
@@ -298,3 +308,32 @@ def reconsider_logprob(rv: TensorVariable) -> TensorVariable:
298308
new_node.op = copy(new_node.op)
299309
new_node.op.__class__ = original_op_type
300310
return new_node.outputs[node.outputs.index(rv)]
311+
312+
313+
def ignore_logprob_multiple_vars(
314+
vars: Sequence[TensorVariable], rvs_to_values: Dict[TensorVariable, TensorVariable]
315+
) -> List[TensorVariable]:
316+
"""Return duplicated variables that are ignored when creating logprob graphs.
317+
318+
This function keeps any interdependencies between variables intact, after
319+
making each "unmeasurable", whereas a sequential call to `ignore_logprob`
320+
would not do this correctly.
321+
"""
322+
from pymc.pytensorf import _replace_rvs_in_graphs
323+
324+
measurable_vars_to_unmeasurable_vars = {
325+
measurable_var: ignore_logprob(measurable_var) for measurable_var in vars
326+
}
327+
328+
def replacement_fn(var, replacements):
329+
if var in measurable_vars_to_unmeasurable_vars:
330+
replacements[var] = measurable_vars_to_unmeasurable_vars[var]
331+
# We don't want to clone valued nodes. Assigning a var to itself in the
332+
# replacements prevents this
333+
elif var in rvs_to_values:
334+
replacements[var] = var
335+
336+
return []
337+
338+
unmeasurable_vars, _ = _replace_rvs_in_graphs(graphs=vars, replacement_fn=replacement_fn)
339+
return unmeasurable_vars

0 commit comments

Comments
 (0)