Skip to content

Commit 24846fd

Browse files
committed
Try to run full test suite in Numba backend
1 parent 21f1895 commit 24846fd

File tree

3 files changed

+36
-16
lines changed

3 files changed

+36
-16
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ jobs:
143143
shell: micromamba-shell {0}
144144
run: |
145145
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
146-
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
146+
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"
147147
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
148148
pip install -e ./
149149
micromamba list && pip freeze

pytensor/compile/mode.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,8 @@ def register_linker(name, linker):
6161
# If a string is passed as the optimizer argument in the constructor
6262
# for Mode, it will be used as the key to retrieve the real optimizer
6363
# in this dictionary
64-
exclude = []
65-
if not config.cxx:
66-
exclude = ["cxx_only"]
64+
65+
exclude = ["cxx_only", "BlasOpt"]
6766
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
6867
# Even if multiple merge optimizer call will be there, this shouldn't
6968
# impact performance.
@@ -340,6 +339,11 @@ def __setstate__(self, state):
340339
optimizer = predefined_optimizers[optimizer]
341340
if isinstance(optimizer, RewriteDatabaseQuery):
342341
self.provided_optimizer = optimizer
342+
343+
# Force numba-required rewrites if using NumbaLinker
344+
if isinstance(linker, NumbaLinker):
345+
optimizer = optimizer.including("numba")
346+
343347
self._optimizer = optimizer
344348
self.call_time = 0
345349
self.fn_time = 0
@@ -437,16 +441,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
437441
# string as the key
438442
# Use VM_linker to allow lazy evaluation by default.
439443
FAST_COMPILE = Mode(
440-
VMLinker(use_cloop=False, c_thunks=False),
441-
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
444+
NumbaLinker(),
445+
# TODO: Fast_compile should just use python code, CHANGE ME!
446+
RewriteDatabaseQuery(
447+
include=["fast_compile", "numba"],
448+
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
449+
),
450+
)
451+
FAST_RUN = Mode(
452+
NumbaLinker(),
453+
RewriteDatabaseQuery(
454+
include=["fast_run", "numba"],
455+
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
456+
),
442457
)
443-
if config.cxx:
444-
FAST_RUN = Mode("cvm", "fast_run")
445-
else:
446-
FAST_RUN = Mode(
447-
"vm",
448-
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
449-
)
450458

451459
JAX = Mode(
452460
JAXLinker(),
@@ -512,7 +520,7 @@ def get_mode(orig_string):
512520
# NanGuardMode use its own linker.
513521
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
514522
else:
515-
# TODO: Can't we look up the name and invoke it rather than using eval here?
523+
# TODO: Get rid of this? Or refactor?
516524
ret = eval(string + "(linker=config.linker, optimizer=config.optimizer)")
517525
elif string in predefined_modes:
518526
ret = predefined_modes[string]
@@ -541,6 +549,7 @@ def register_mode(name, mode):
541549
Add a `Mode` which can be referred to by `name` in `function`.
542550
543551
"""
552+
# TODO: Remove me
544553
if name in predefined_modes:
545554
raise ValueError(f"Mode name already taken: {name}")
546555
predefined_modes[name] = mode

pytensor/configdefaults.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,18 @@ def add_compile_configvars():
461461
"linker",
462462
"Default linker used if the pytensor flags mode is Mode",
463463
EnumStr(
464-
"cvm", ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]
464+
"numba",
465+
[
466+
"cvm",
467+
"jax",
468+
"c|py",
469+
"py",
470+
"c",
471+
"c|py_nogc",
472+
"vm",
473+
"vm_nogc",
474+
"cvm_nogc",
475+
],
465476
),
466477
in_c_key=False,
467478
)
@@ -471,7 +482,7 @@ def add_compile_configvars():
471482
config.add(
472483
"linker",
473484
"Default linker used if the pytensor flags mode is Mode",
474-
EnumStr("vm", ["py", "vm_nogc"]),
485+
EnumStr("numba", ["vm", "jax", "py", "vm_nogc"]),
475486
in_c_key=False,
476487
)
477488
if type(config).cxx.is_default:

0 commit comments

Comments
 (0)