Skip to content

Commit db15ae4

Browse files
committed
Avoid cloning of Minibatch values
1 parent cafb60b commit db15ae4

File tree

4 files changed

+60
-8
lines changed

4 files changed

+60
-8
lines changed

pymc/data.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
import pymc as pm
3838

39+
from pymc.logprob.abstract import _get_measurable_outputs
3940
from pymc.pytensorf import convert_observed_data
4041

4142
__all__ = [
@@ -134,6 +135,11 @@ def make_node(self, rng, *args, **kwargs):
134135
return super().make_node(rng, *args, **kwargs)
135136

136137

138+
@_get_measurable_outputs.register(MinibatchIndexRV)
139+
def minibatch_index_rv_measuarable_outputs(op, node):
140+
return []
141+
142+
137143
minibatch_index = MinibatchIndexRV()
138144

139145

pymc/logprob/basic.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,13 @@
4444
import pytensor.tensor as pt
4545

4646
from pytensor import config
47-
from pytensor.graph.basic import Variable, graph_inputs, io_toposort
47+
from pytensor.graph.basic import (
48+
Constant,
49+
Variable,
50+
ancestors,
51+
graph_inputs,
52+
io_toposort,
53+
)
4854
from pytensor.graph.op import compute_test_value
4955
from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter
5056
from pytensor.tensor.random.op import RandomVariable
@@ -231,10 +237,16 @@ def factorized_joint_logprob(
231237
# node.
232238
replacements = updated_rv_values.copy()
233239

234-
# To avoid cloning the value variables, we map them to themselves in the
235-
# `replacements` `dict` (i.e. entries already existing in `replacements`
236-
# aren't cloned)
237-
replacements.update({v: v for v in rv_values.values()})
240+
# To avoid cloning the value variables (or their non-constant ancestors),
241+
# we map them to themselves in the `replacements` `dict`
242+
# (i.e. entries already existing in `replacements` aren't cloned)
243+
replacements.update(
244+
{
245+
v: v
246+
for v in ancestors(rv_values.values())
247+
if (not isinstance(v, Constant) and v not in replacements)
248+
}
249+
)
238250

239251
# Walk the graph from its inputs to its outputs and construct the
240252
# log-probability

pymc/logprob/rewriting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
import pytensor.tensor as pt
4040

4141
from pytensor.compile.mode import optdb
42-
from pytensor.graph.basic import Variable
42+
from pytensor.graph.basic import Constant, Variable, ancestors
4343
from pytensor.graph.features import Feature
4444
from pytensor.graph.fg import FunctionGraph
4545
from pytensor.graph.rewriting.basic import GraphRewriter, node_rewriter
@@ -316,8 +316,8 @@ def construct_ir_fgraph(
316316
# the old nodes to the new ones; otherwise, we won't be able to use
317317
# `rv_values`.
318318
# We start the `dict` with mappings from the value variables to themselves,
319-
# to prevent them from being cloned.
320-
memo = {v: v for v in rv_values.values()}
319+
# to prevent them from being cloned. This also includes their non-constant ancestors.
320+
memo = {v: v for v in ancestors(rv_values.values()) if not isinstance(v, Constant)}
321321

322322
# We add `ShapeFeature` because it will get rid of references to the old
323323
# `RandomVariable`s that have been lifted; otherwise, it will be difficult

tests/variational/test_minibatch_rv.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pymc as pm
2121

2222
from pymc import Normal, draw
23+
from pymc.data import minibatch_index
2324
from pymc.testing import select_by_precision
2425
from pymc.variational.minibatch_rv import create_minibatch_rv
2526
from tests.test_data import gen1, gen2
@@ -155,3 +156,36 @@ def test_random(self):
155156
mx = create_minibatch_rv(x, total_size=(10,))
156157
assert mx is not x
157158
np.testing.assert_array_equal(draw(mx, random_seed=1), draw(x, random_seed=1))
159+
160+
@pytest.mark.filterwarnings("error")
161+
def test_minibatch_parameter_and_value(self):
162+
rng = np.random.default_rng(161)
163+
total_size = 1000
164+
165+
with pm.Model(check_bounds=False) as m:
166+
AD = pm.MutableData("AD", np.arange(total_size, dtype="float64"))
167+
TD = pm.MutableData("TD", np.arange(total_size, dtype="float64"))
168+
169+
minibatch_idx = minibatch_index(0, 10, size=(9,))
170+
AD_mt = AD[minibatch_idx]
171+
TD_mt = TD[minibatch_idx]
172+
173+
pm.Normal(
174+
"AD_predicted",
175+
mu=TD_mt,
176+
observed=AD_mt,
177+
total_size=1000,
178+
)
179+
180+
logp_fn = m.compile_logp()
181+
182+
ip = m.initial_point()
183+
np.testing.assert_allclose(logp_fn(ip), st.norm.logpdf(0) * 1000)
184+
185+
with m:
186+
pm.set_data({"AD": np.arange(total_size) + 1})
187+
np.testing.assert_allclose(logp_fn(ip), st.norm.logpdf(1) * 1000)
188+
189+
with m:
190+
pm.set_data({"AD": rng.normal(size=1000)})
191+
assert logp_fn(ip) != logp_fn(ip)

0 commit comments

Comments
 (0)