13
13
# limitations under the License.
14
14
from collections .abc import Sequence
15
15
16
- from pytensor import Variable
17
- from pytensor .graph import ancestors , node_rewriter
18
- from pytensor .graph .rewriting . basic import out2in
16
+ from pytensor import Variable , clone_replace
17
+ from pytensor .graph import ancestors
18
+ from pytensor .graph .fg import FunctionGraph
19
19
20
20
from pymc .data import MinibatchOp
21
21
from pymc .model .core import Model
@@ -62,7 +62,7 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
62
62
return [model [var ] if isinstance (var , str ) else var for var in vars_seq ]
63
63
64
64
65
- def remove_minibatched_nodes (model : pm . Model ) -> pm . Model :
65
+ def remove_minibatched_nodes (model : Model ) -> Model :
66
66
"""Remove all uses of pm.Minibatch in the Model."""
67
67
fgraph , _ = fgraph_from_model (model )
68
68
@@ -75,11 +75,10 @@ def remove_minibatched_nodes(model: pm.Model) -> pm.Model:
75
75
old_outs , old_coords , old_dim_lengths = fgraph .outputs , fgraph ._coords , fgraph ._dim_lengths
76
76
# Using `rebuild_strict=False` means all coords, names, and dim information is lost
77
77
# So we need to restore it from the old fgraph
78
- new_outs = pytensor . clone_replace (old_outs , replacements , rebuild_strict = False )
78
+ new_outs = clone_replace (old_outs , replacements , rebuild_strict = False )
79
79
for old_out , new_out in zip (old_outs , new_outs ):
80
80
new_out .name = old_out .name
81
- fgraph = pytensor . graph . fg . FunctionGraph (outputs = new_outs , clone = False )
81
+ fgraph = FunctionGraph (outputs = new_outs , clone = False )
82
82
fgraph ._coords = old_coords
83
83
fgraph ._dim_lengths = old_dim_lengths
84
84
return model_from_fgraph (fgraph , mutate_fgraph = True )
85
-
0 commit comments