Skip to content

Commit 98bd344

Browse files
committed
Rework transform
1 parent 286df98 commit 98bd344

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

pymc/model/transform/basic.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,24 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
6262
return [model[var] if isinstance(var, str) else var for var in vars_seq]
6363

6464

65-
def remove_minibatched_nodes(model: Model):
65+
def remove_minibatched_nodes(model: pm.Model) -> pm.Model:
6666
"""Remove all uses of pm.Minibatch in the Model."""
67+
fgraph, _ = fgraph_from_model(model)
6768

68-
@node_rewriter([MinibatchOp])
69-
def local_remove_minibatch(fgraph, node):
70-
return node.inputs
69+
replacements = {}
70+
for var in fgraph.apply_nodes:
71+
if isinstance(var.op, MinibatchOp):
72+
for inp, out in zip(var.inputs, var.outputs):
73+
replacements[out] = inp
7174

72-
remove_minibatch = out2in(local_remove_minibatch)
73-
fgraph, _ = fgraph_from_model(model)
74-
remove_minibatch.apply(fgraph)
75+
old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths
76+
# Using `rebuild_strict=False` means all coords, names, and dim information is lost
77+
# So we need to restore it from the old fgraph
78+
new_outs = pytensor.clone_replace(old_outs, replacements, rebuild_strict=False)
79+
for old_out, new_out in zip(old_outs, new_outs):
80+
new_out.name = old_out.name
81+
fgraph = pytensor.graph.fg.FunctionGraph(outputs=new_outs, clone=False)
82+
fgraph._coords = old_coords
83+
fgraph._dim_lengths = old_dim_lengths
7584
return model_from_fgraph(fgraph, mutate_fgraph=True)
85+

0 commit comments

Comments
 (0)