Skip to content

Commit df2ffe4

Browse files
committed
Add mode for rewrites needed when executing function in Python mode
1 parent 4ed8767 commit df2ffe4

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

pytensor/compile/mode.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def apply(self, fgraph):
262262

263263
# final pass just to make sure
264264
optdb.register("merge3", MergeOptimizer(), "fast_run", "merge", position=100)
265+
optdb.register("py_only", EquilibriumDB(), "fast_compile", position=100)
265266

266267
_tags: Union[Tuple[str, str], Tuple]
267268

@@ -439,11 +440,17 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
439440
# FunctionMaker, the Mode will be taken from this dictionary using the
440441
# string as the key
441442
# Use VM_linker to allow lazy evaluation by default.
442-
FAST_COMPILE = Mode(VMLinker(use_cloop=False, c_thunks=False), "fast_compile")
443+
FAST_COMPILE = Mode(
444+
VMLinker(use_cloop=False, c_thunks=False),
445+
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
446+
)
443447
if config.cxx:
444448
FAST_RUN = Mode("cvm", "fast_run")
445449
else:
446-
FAST_RUN = Mode("vm", "fast_run")
450+
FAST_RUN = Mode(
451+
"vm",
452+
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
453+
)
447454

448455
JAX = Mode(
449456
JAXLinker(),

0 commit comments

Comments
 (0)