@@ -62,14 +62,24 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
62
62
return [model [var ] if isinstance (var , str ) else var for var in vars_seq ]
63
63
64
64
65
- def remove_minibatched_nodes (model : Model ):
65
+ def remove_minibatched_nodes (model : pm . Model ) -> pm . Model :
66
66
"""Remove all uses of pm.Minibatch in the Model."""
67
+ fgraph , _ = fgraph_from_model (model )
67
68
68
- @node_rewriter ([MinibatchOp ])
69
- def local_remove_minibatch (fgraph , node ):
70
- return node .inputs
69
+ replacements = {}
70
+ for var in fgraph .apply_nodes :
71
+ if isinstance (var .op , MinibatchOp ):
72
+ for inp , out in zip (var .inputs , var .outputs ):
73
+ replacements [out ] = inp
71
74
72
- remove_minibatch = out2in (local_remove_minibatch )
73
- fgraph , _ = fgraph_from_model (model )
74
- remove_minibatch .apply (fgraph )
75
+ old_outs , old_coords , old_dim_lengths = fgraph .outputs , fgraph ._coords , fgraph ._dim_lengths
76
+ # Using `rebuild_strict=False` means all coords, names, and dim information is lost
77
+ # So we need to restore it from the old fgraph
78
+ new_outs = pytensor .clone_replace (old_outs , replacements , rebuild_strict = False )
79
+ for old_out , new_out in zip (old_outs , new_outs ):
80
+ new_out .name = old_out .name
81
+ fgraph = pytensor .graph .fg .FunctionGraph (outputs = new_outs , clone = False )
82
+ fgraph ._coords = old_coords
83
+ fgraph ._dim_lengths = old_dim_lengths
75
84
return model_from_fgraph (fgraph , mutate_fgraph = True )
85
+
0 commit comments