Run whole test suite on numba backend #811

Draft · wants to merge 12 commits into main
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -193,7 +193,7 @@ jobs:
else
micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock;
fi
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
pip install pytest-sphinx
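With the `INSTALL_NUMBA` guard removed, every job in this matrix installs Numba. A purely illustrative sanity check mirroring the `>=0.57` pin used above (not part of the workflow):

```python
import numba

# Assumes numba is importable in the CI environment; the workflow pins
# "numba>=0.57", so anything older should be rejected.
major, minor = (int(part) for part in numba.__version__.split(".")[:2])
assert (major, minor) >= (0, 57), numba.__version__
```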
33 changes: 21 additions & 12 deletions pytensor/compile/mode.py
@@ -63,9 +63,8 @@ def register_linker(name, linker):
# If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary
exclude = []
if not config.cxx:
exclude = ["cxx_only"]

exclude = ["cxx_only", "BlasOpt"]
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
# Even if multiple merge optimizer calls are made, this shouldn't
# impact performance.
@@ -346,6 +345,11 @@ def __setstate__(self, state):
optimizer = predefined_optimizers[optimizer]
if isinstance(optimizer, RewriteDatabaseQuery):
self.provided_optimizer = optimizer

# Force numba-required rewrites if using NumbaLinker
if isinstance(linker, NumbaLinker):
optimizer = optimizer.including("numba")

self._optimizer = optimizer
self.call_time = 0
self.fn_time = 0
@@ -443,16 +447,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
# string as the key
# Use VM_linker to allow lazy evaluation by default.
FAST_COMPILE = Mode(
VMLinker(use_cloop=False, c_thunks=False),
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
NumbaLinker(),
# TODO: Fast_compile should just use python code, CHANGE ME!
RewriteDatabaseQuery(
include=["fast_compile", "numba"],
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
),
)
FAST_RUN = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
include=["fast_run", "numba"],
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
),
)
if config.cxx:
FAST_RUN = Mode("cvm", "fast_run")
else:
FAST_RUN = Mode(
"vm",
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
)

NUMBA = Mode(
NumbaLinker(),
@@ -565,6 +573,7 @@ def register_mode(name, mode):
Add a `Mode` which can be referred to by `name` in `function`.

"""
# TODO: Remove me
if name in predefined_modes:
raise ValueError(f"Mode name already taken: {name}")
predefined_modes[name] = mode
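Taken together, the mode changes above make both `FAST_COMPILE` and `FAST_RUN` build on `NumbaLinker` with the `"numba"` rewrites included. A minimal sketch of how that surfaces to users (the graph here is illustrative, not part of the diff):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x).sum()

# With this branch, FAST_RUN resolves to NumbaLinker plus the
# "fast_run"/"numba" rewrite query instead of the C/VM backend.
f = pytensor.function([x], y, mode="FAST_RUN")
print(f(np.arange(3, dtype="float64")))
```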
16 changes: 13 additions & 3 deletions pytensor/configdefaults.py
@@ -370,11 +370,21 @@ def add_compile_configvars():

if rc == 0 and config.cxx != "":
# Keep the default linker the same as the one for the mode FAST_RUN
linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]
linker_options = [
"cvm",
"c|py",
"py",
"c",
"c|py_nogc",
"vm",
"vm_nogc",
"cvm_nogc",
"jax",
]
else:
# g++ is not present or the user disabled it,
# linker should default to python only.
linker_options = ["py", "vm_nogc"]
linker_options = ["py", "vm", "vm_nogc", "jax"]
if type(config).cxx.is_default:
# If the user provided an empty value for cxx, do not warn.
_logger.warning(
@@ -388,7 +398,7 @@
"linker",
"Default linker used if the pytensor flags mode is Mode",
# Not mutable because the default mode is cached after the first use.
EnumStr("cvm", linker_options, mutable=False),
EnumStr("numba", linker_options, mutable=False),
in_c_key=False,
)

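Because the `linker` flag now defaults to `numba` and is not mutable after startup, opting back into another linker goes through `PYTENSOR_FLAGS`. A small sketch, assuming the standard PyTensor config mechanism:

```python
import pytensor

# Overrides must be set before import, e.g.
#   PYTENSOR_FLAGS="linker=vm" python my_script.py
print(pytensor.config.linker)  # expected to be "numba" on this branch
print(pytensor.config.mode)
```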
3 changes: 3 additions & 0 deletions pytensor/link/jax/linker.py
@@ -87,3 +87,6 @@ def create_thunk_inputs(self, storage_map):
thunk_inputs.append(sinput)

return thunk_inputs

def __repr__(self):
return "JAXLinker()"
9 changes: 6 additions & 3 deletions pytensor/link/numba/dispatch/basic.py
@@ -40,6 +40,7 @@
from pytensor.tensor.slinalg import Solve
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import MakeSlice, NoneConst
from pytensor.typed_list import TypedListType


def global_numba_func(func):
@@ -135,6 +136,8 @@ def get_numba_type(
return CSCMatrixType(numba_dtype)

raise NotImplementedError()
elif isinstance(pytensor_type, TypedListType):
return numba.types.List(get_numba_type(pytensor_type.ttype))
else:
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")

@@ -481,11 +484,11 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]

func_conditions = [
f"assert x.shape[{i}] == {shape_input_names}"
for i, (shape_input, shape_input_names) in enumerate(
f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'"
for i, (node_dim_input, eval_dim_name) in enumerate(
zip(shape_inputs, shape_input_names, strict=True)
)
if shape_input is not NoneConst
if node_dim_input is not NoneConst
]

func = dedent(
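The `SpecifyShape` change mostly improves the assertion message produced by the generated Numba code. A rough end-to-end sketch of where that assert fires (shapes and names are illustrative):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.specify_shape(x, (3,)) * 2

f = pytensor.function([x], y, mode="NUMBA")
f(np.zeros(3))    # satisfies the generated shape assert
# f(np.zeros(4))  # would now fail with a message naming the dim and observed shape
```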
10 changes: 9 additions & 1 deletion pytensor/link/numba/dispatch/elemwise.py
@@ -411,7 +411,15 @@ def numba_funcify_CAReduce(op, node, **kwargs):


@numba_funcify.register(DimShuffle)
def numba_funcify_DimShuffle(op, node, **kwargs):
def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs):
if op.is_left_expand_dims and op.new_order.count("x") == 1:
# Most common case, numba compiles it more quickly
@numba_njit
def left_expand_dims(x):
return np.expand_dims(x, 0)

return left_expand_dims

# We use `as_strided` to achieve the DimShuffle behavior of transposing and expanding/squeezing dimensions in one call
# Numba doesn't currently support multiple expand/squeeze, and reshape is limited to contiguous arrays.
new_order = tuple(op._new_order)
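The new `DimShuffle` fast path only covers a single leading expand_dims; everything else still goes through `as_strided`. A pure-Numba sketch of the equivalence it relies on (illustrative only):

```python
import numpy as np
from numba import njit


# DimShuffle orders like ("x", 0, 1) just prepend a broadcastable axis,
# which Numba compiles quickly as a plain expand_dims.
@njit
def left_expand_dims(x):
    return np.expand_dims(x, 0)


a = np.arange(6.0).reshape(2, 3)
assert left_expand_dims(a).shape == (1, 2, 3)
```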
68 changes: 65 additions & 3 deletions pytensor/link/numba/dispatch/scalar.py
@@ -23,6 +23,7 @@
Composite,
Identity,
Mul,
Pow,
Reciprocal,
ScalarOp,
Second,
@@ -154,6 +155,21 @@ def numba_funcify_Switch(op, node, **kwargs):
return numba_basic.global_numba_func(switch)


@numba_funcify.register(Pow)
def numba_funcify_Pow(op, node, **kwargs):
pow_dtype = node.inputs[1].type.dtype

def pow(x, y):
return x**y

# Work-around https://github.com/numba/numba/issues/9554
# fast-math causes kernel crashes
patch_kwargs = {}
if pow_dtype.startswith("int"):
patch_kwargs["fastmath"] = False
return numba_basic.numba_njit(**patch_kwargs)(pow)


def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
"""Create a Numba-compatible N-ary function from a binary function."""
unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_")
@@ -172,18 +188,64 @@ def {binary_op_name}({input_signature}):

@numba_funcify.register(Add)
def numba_funcify_Add(op, node, **kwargs):
match len(node.inputs):
case 2:

def add(i0, i1):
return i0 + i1
case 3:

def add(i0, i1, i2):
return i0 + i1 + i2
case 4:

def add(i0, i1, i2, i3):
return i0 + i1 + i2 + i3
case 5:

def add(i0, i1, i2, i3, i4):
return i0 + i1 + i2 + i3 + i4
case _:
add = None

if add is not None:
return numba_basic.numba_njit(add)

signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

return numba_basic.numba_njit(signature)(nary_add_fn)
return numba_basic.numba_njit(signature, cache=False)(nary_add_fn)


@numba_funcify.register(Mul)
def numba_funcify_Mul(op, node, **kwargs):
match len(node.inputs):
case 2:

def mul(i0, i1):
return i0 * i1
case 3:

def mul(i0, i1, i2):
return i0 * i1 * i2
case 4:

def mul(i0, i1, i2, i3):
return i0 * i1 * i2 * i3
case 5:

def mul(i0, i1, i2, i3, i4):
return i0 * i1 * i2 * i3 * i4
case _:
mul = None

if mul is not None:
return numba_basic.numba_njit(mul)

signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")

return numba_basic.numba_njit(signature)(nary_add_fn)
return numba_basic.numba_njit(signature, cache=False)(nary_add_fn)


@numba_funcify.register(Cast)
@@ -233,7 +295,7 @@ def numba_funcify_Composite(op, node, **kwargs):

_ = kwargs.pop("storage_map", None)

composite_fn = numba_basic.numba_njit(signature)(
composite_fn = numba_basic.numba_njit(signature, cache=False)(
numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
)
return composite_fn
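The `Pow` registration above disables fast-math only for integer exponents, working around the linked Numba issue, while the `Add`/`Mul` match statements trade the string-generated n-ary helper for fixed-arity kernels in the common 2 to 5 input cases. A standalone sketch of the fast-math work-around (names are illustrative):

```python
import numba


def make_pow(exponent_dtype: str):
    def pow_(x, y):
        return x**y

    # Integer powers can miscompile under fastmath (numba/numba#9554),
    # so fastmath is switched off only for integer exponent dtypes.
    kwargs = {"fastmath": False} if exponent_dtype.startswith("int") else {}
    return numba.njit(**kwargs)(pow_)


int_pow = make_pow("int64")
print(int_pow(2, 10))  # 1024
```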