From 66f61d09b5e7c7ba87ab7b481f1722f51ba66015 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Tue, 1 Apr 2025 01:58:39 +0200 Subject: [PATCH 1/3] Transform to remove Minibatch from model --- pymc/model/transform/basic.py | 17 ++++++++++++++++- tests/model/transform/test_basic.py | 18 +++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index fcf42fdf8c..45306ff144 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -14,8 +14,10 @@ from collections.abc import Sequence from pytensor import Variable -from pytensor.graph import ancestors +from pytensor.graph import ancestors, node_rewriter +from pytensor.graph.rewriting.basic import out2in +from pymc.data import MinibatchOp from pymc.model.core import Model from pymc.model.fgraph import ( ModelObservedRV, @@ -58,3 +60,16 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l else: vars_seq = (vars,) return [model[var] if isinstance(var, str) else var for var in vars_seq] + + +def remove_minibatched_nodes(model: Model): + """Remove all uses of pm.Minibatch in the Model.""" + + @node_rewriter([MinibatchOp]) + def local_remove_minibatch(fgraph, node): + return node.inputs + + remove_minibatch = out2in(local_remove_minibatch) + fgraph, _ = fgraph_from_model(model) + remove_minibatch.apply(fgraph) + return model_from_fgraph(fgraph, mutate_fgraph=True) diff --git a/tests/model/transform/test_basic.py b/tests/model/transform/test_basic.py index 25bf2324ec..16e0c4ac1f 100644 --- a/tests/model/transform/test_basic.py +++ b/tests/model/transform/test_basic.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np + import pymc as pm -from pymc.model.transform.basic import prune_vars_detached_from_observed +from pymc.model.transform.basic import prune_vars_detached_from_observed, remove_minibatched_nodes def test_prune_vars_detached_from_observed(): @@ -30,3 +32,17 @@ def test_prune_vars_detached_from_observed(): assert set(m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs", "d0", "d1"} pruned_m = prune_vars_detached_from_observed(m) assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"} + + +def test_remove_minibatches(): + data_size = 100 + data = np.zeros((data_size,)) + batch_size = 10 + with pm.Model() as m1: + mb = pm.Minibatch(data, batch_size=batch_size) + x = pm.Normal("x") + y = pm.Normal("y", x, observed=mb, total_size=100) + + m2 = remove_minibatched_nodes(m1) + assert m1.y.shape[0].eval() == batch_size + assert m2.y.shape[0].eval() == data_size From beccca42b69c3d3ad1760d691d25dace86aed188 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sat, 5 Apr 2025 23:40:47 +0200 Subject: [PATCH 2/3] Rework transform --- pymc/model/transform/basic.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index 45306ff144..877814cd61 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -13,9 +13,9 @@ # limitations under the License. from collections.abc import Sequence -from pytensor import Variable -from pytensor.graph import ancestors, node_rewriter -from pytensor.graph.rewriting.basic import out2in +from pytensor import Variable, clone_replace +from pytensor.graph import ancestors +from pytensor.graph.fg import FunctionGraph from pymc.data import MinibatchOp from pymc.model.core import Model @@ -62,14 +62,23 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l return [model[var] if isinstance(var, str) else var for var in vars_seq] -def remove_minibatched_nodes(model: Model): +def remove_minibatched_nodes(model: Model) -> Model: """Remove all uses of pm.Minibatch in the Model.""" + fgraph, _ = fgraph_from_model(model) - @node_rewriter([MinibatchOp]) - def local_remove_minibatch(fgraph, node): - return node.inputs + replacements = {} + for var in fgraph.apply_nodes: + if isinstance(var.op, MinibatchOp): + for inp, out in zip(var.inputs, var.outputs): + replacements[out] = inp - remove_minibatch = out2in(local_remove_minibatch) - fgraph, _ = fgraph_from_model(model) - remove_minibatch.apply(fgraph) + old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths # type: ignore[attr-defined] + # Using `rebuild_strict=False` means all coords, names, and dim information is lost + # So we need to restore it from the old fgraph + new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type] + for old_out, new_out in zip(old_outs, new_outs): + new_out.name = old_out.name + fgraph = FunctionGraph(outputs=new_outs, clone=False) + fgraph._coords = old_coords # type: ignore[attr-defined] + fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined] return model_from_fgraph(fgraph, mutate_fgraph=True) From 8bdd5b7816ab3e89c417fe5075d7e1b226bed327 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 20 Apr 2025 18:35:27 +0200 Subject: [PATCH 3/3] Update test to check that coords and dim_lengths are preserved --- tests/model/transform/test_basic.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/model/transform/test_basic.py b/tests/model/transform/test_basic.py index 16e0c4ac1f..856fbf0b2b 100644 --- a/tests/model/transform/test_basic.py +++ b/tests/model/transform/test_basic.py @@ -38,11 +38,14 @@ def test_remove_minibatches(): data_size = 100 data = np.zeros((data_size,)) batch_size = 10 - with pm.Model() as m1: + with pm.Model(coords={"d": range(5)}) as m1: mb = pm.Minibatch(data, batch_size=batch_size) + mu = pm.Normal("mu", dims="d") x = pm.Normal("x") y = pm.Normal("y", x, observed=mb, total_size=100) m2 = remove_minibatched_nodes(m1) assert m1.y.shape[0].eval() == batch_size assert m2.y.shape[0].eval() == data_size + assert m1.coords == m2.coords + assert m1.dim_lengths["d"].eval() == m2.dim_lengths["d"].eval()