diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py
index c081fbe9ef..fb63706139 100644
--- a/pytensor/link/numba/dispatch/basic.py
+++ b/pytensor/link/numba/dispatch/basic.py
@@ -1,4 +1,5 @@
 import operator
+import sys
 import warnings
 from contextlib import contextmanager
 from functools import singledispatch
@@ -10,7 +11,7 @@
 import numpy as np
 import scipy
 import scipy.special
-from llvmlite.llvmpy.core import Type as llvm_Type
+from llvmlite import ir
 from numba import types
 from numba.core.errors import TypingError
 from numba.cpython.unsafe.tuple import tuple_setitem  # noqa: F401
@@ -48,10 +49,13 @@ def numba_njit(*args, **kwargs):
+    kwargs = kwargs.copy()
+    kwargs.setdefault("cache", config.numba__cache)
+
     if len(args) > 0 and callable(args[0]):
-        return numba.njit(*args[1:], cache=config.numba__cache, **kwargs)(args[0])
+        return numba.njit(*args[1:], **kwargs)(args[0])

-    return numba.njit(*args, cache=config.numba__cache, **kwargs)
+    return numba.njit(*args, **kwargs)


 def numba_vectorize(*args, **kwargs):
@@ -128,7 +132,7 @@ def create_numba_signature(


 def slice_new(self, start, stop, step):
-    fnty = llvm_Type.function(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
+    fnty = ir.FunctionType(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
     fn = self._get_function(fnty, name="PySlice_New")
     return self.builder.call(fn, [start, stop, step])
@@ -147,11 +151,33 @@ def box_slice(typ, val, c):
     This makes it possible to return an Numba's internal representation of a
     ``slice`` object as a proper ``slice`` to Python.
     """
+    start = c.builder.extract_value(val, 0)
+    stop = c.builder.extract_value(val, 1)
+
+    none_val = ir.Constant(ir.IntType(64), sys.maxsize)
+
+    start_is_none = c.builder.icmp_signed("==", start, none_val)
+    start = c.builder.select(
+        start_is_none,
+        c.pyapi.get_null_object(),
+        c.box(types.int64, start),
+    )
+
+    stop_is_none = c.builder.icmp_signed("==", stop, none_val)
+    stop = c.builder.select(
+        stop_is_none,
+        c.pyapi.get_null_object(),
+        c.box(types.int64, stop),
+    )

-    start = c.box(types.int64, c.builder.extract_value(val, 0))
-    stop = c.box(types.int64, c.builder.extract_value(val, 1))
     if typ.has_step:
-        step = c.box(types.int64, c.builder.extract_value(val, 2))
+        step = c.builder.extract_value(val, 2)
+        step_is_none = c.builder.icmp_signed("==", step, none_val)
+        step = c.builder.select(
+            step_is_none,
+            c.pyapi.get_null_object(),
+            c.box(types.int64, step),
+        )
     else:
         step = c.pyapi.get_null_object()
@@ -319,9 +345,8 @@ def numba_typify(data, dtype=None, **kwargs):
     return data


-@singledispatch
-def numba_funcify(op, node=None, storage_map=None, **kwargs):
-    """Create a Numba compatible function from an PyTensor `Op`."""
+def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
+    """Create a Numba compatible function from a PyTensor `Op`."""

     warnings.warn(
         f"Numba will use object mode to run {op}'s perform method",
@@ -375,6 +400,12 @@ def perform(*inputs):
     return perform


+@singledispatch
+def numba_funcify(op, node=None, storage_map=None, **kwargs):
+    """Generate a numba function for a given op and apply node."""
+    return generate_fallback_impl(op, node, storage_map, **kwargs)
+
+
 @numba_funcify.register(OpFromGraph)
 def numba_funcify_OpFromGraph(op, node=None, **kwargs):
@@ -506,7 +537,6 @@ def {fn_name}({", ".join(input_names)}):


 @numba_funcify.register(Subtensor)
-@numba_funcify.register(AdvancedSubtensor)
 @numba_funcify.register(AdvancedSubtensor1)
 def numba_funcify_Subtensor(op, node, **kwargs):
@@ -524,7 +554,6 @@ def numba_funcify_Subtensor(op, node, **kwargs):


 @numba_funcify.register(IncSubtensor)
-@numba_funcify.register(AdvancedIncSubtensor)
 def numba_funcify_IncSubtensor(op, node, **kwargs):
     incsubtensor_def_src = create_index_func(
diff --git a/pytensor/link/numba/dispatch/cython_support.py b/pytensor/link/numba/dispatch/cython_support.py
new file mode 100644
index 0000000000..ab551bb1d3
--- /dev/null
+++ b/pytensor/link/numba/dispatch/cython_support.py
@@ -0,0 +1,211 @@
+import ctypes
+import importlib
+import re
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast
+
+import numba
+import numpy as np
+from numpy.typing import DTypeLike
+from scipy import LowLevelCallable
+
+
+_C_TO_NUMPY: Dict[str, DTypeLike] = {
+    "bool": np.bool_,
+    "signed char": np.byte,
+    "unsigned char": np.ubyte,
+    "short": np.short,
+    "unsigned short": np.ushort,
+    "int": np.intc,
+    "unsigned int": np.uintc,
+    "long": np.int_,
+    "unsigned long": np.uint,
+    "long long": np.longlong,
+    "float": np.single,
+    "double": np.double,
+    "long double": np.longdouble,
+    "float complex": np.csingle,
+    "double complex": np.cdouble,
+}
+
+
+@dataclass
+class Signature:
+    res_dtype: DTypeLike
+    res_c_type: str
+    arg_dtypes: List[DTypeLike]
+    arg_c_types: List[str]
+    arg_names: List[Optional[str]]
+
+    @property
+    def arg_numba_types(self) -> List[DTypeLike]:
+        return [numba.from_dtype(dtype) for dtype in self.arg_dtypes]
+
+    def can_cast_args(self, args: List[DTypeLike]) -> bool:
+        ok = True
+        count = 0
+        for name, dtype in zip(self.arg_names, self.arg_dtypes):
+            if name == "__pyx_skip_dispatch":
+                continue
+            if len(args) <= count:
+                raise ValueError("Incorrect number of arguments")
+            ok &= np.can_cast(args[count], dtype)
+            count += 1
+        if count != len(args):
+            return False
+        return ok
+
+    def provides(self, restype: DTypeLike, arg_dtypes: List[DTypeLike]) -> bool:
+        args_ok = self.can_cast_args(arg_dtypes)
+        if np.issubdtype(restype, np.inexact):
+            result_ok = np.can_cast(self.res_dtype, restype, casting="same_kind")
+            # We do not want to provide less accuracy than advertised
+            result_ok &= np.dtype(self.res_dtype).itemsize >= np.dtype(restype).itemsize
+        else:
+            result_ok = np.can_cast(self.res_dtype, restype)
+        return args_ok and result_ok
+
+    @staticmethod
+    def from_c_types(signature: bytes) -> "Signature":
+        # Match strings like "double(int, double)"
+        # and extract the return type and the joined arguments
+        expr = re.compile(rb"\s*(?P<restype>[\w ]*\w+)\s*\((?P<args>[\w\s,]*)\)")
+        re_match = re.fullmatch(expr, signature)
+
+        if re_match is None:
+            raise ValueError(f"Invalid signature: {signature.decode()}")
+
+        groups = re_match.groupdict()
+        res_c_type = groups["restype"].decode()
+        res_dtype: DTypeLike = _C_TO_NUMPY[res_c_type]
+
+        raw_args = groups["args"]
+
+        decl_expr = re.compile(
+            rb"\s*(?P<type>((long )|(unsigned )|(signed )|(double )|)"
+            rb"((double)|(float)|(int)|(short)|(char)|(long)|(bool)|(complex)))"
+            rb"(\s(?P<name>[\w_]*))?\s*"
+        )
+
+        arg_dtypes = []
+        arg_names: List[Optional[str]] = []
+        arg_c_types = []
+        for raw_arg in raw_args.split(b","):
+            re_match = re.fullmatch(decl_expr, raw_arg)
+            if re_match is None:
+                raise ValueError(f"Invalid signature: {signature.decode()}")
+            groups = re_match.groupdict()
+            arg_c_type = groups["type"].decode()
+            try:
+                arg_dtype = _C_TO_NUMPY[arg_c_type]
+            except KeyError:
+                raise ValueError(f"Unknown C type: {arg_c_type}")
+
+            arg_c_types.append(arg_c_type)
+            arg_dtypes.append(arg_dtype)
+            name = groups["name"]
+            if not name:
+                arg_names.append(None)
+            else:
+                arg_names.append(name.decode())
+
+        return Signature(res_dtype, res_c_type, arg_dtypes, arg_c_types, arg_names)
+
+
+def _available_impls(func: Callable) -> List[Tuple[Signature, Any]]:
+    """Find all available implementations for a fused cython function."""
+    impls = []
+    mod = importlib.import_module(func.__module__)
+
+    signatures = getattr(func, "__signatures__", None)
+    if signatures is not None:
+        # Cython function with __signatures__ should be fused and thus
+        # indexable
+        func_map = cast(Mapping, func)
+        candidates = [func_map[key] for key in signatures]
+    else:
+        candidates = [func]
+    for candidate in candidates:
+        name = candidate.__name__
+        capsule = mod.__pyx_capi__[name]
+        llc = LowLevelCallable(capsule)
+        try:
+            signature = Signature.from_c_types(llc.signature.encode())
+        except KeyError:
+            continue
+        impls.append((signature, capsule))
+    return impls
+
+
+class _CythonWrapper(numba.types.WrapperAddressProtocol):
+    def __init__(self, pyfunc, signature, capsule):
+        self._keep_alive = capsule
+        get_name = ctypes.pythonapi.PyCapsule_GetName
+        get_name.restype = ctypes.c_char_p
+        get_name.argtypes = (ctypes.py_object,)
+
+        raw_signature = get_name(capsule)
+
+        get_pointer = ctypes.pythonapi.PyCapsule_GetPointer
+        get_pointer.restype = ctypes.c_void_p
+        get_pointer.argtypes = (ctypes.py_object, ctypes.c_char_p)
+        self._func_ptr = get_pointer(capsule, raw_signature)
+
+        self._signature = signature
+        self._pyfunc = pyfunc
+
+    def signature(self):
+        return numba.from_dtype(self._signature.res_dtype)(
+            *self._signature.arg_numba_types
+        )
+
+    def __wrapper_address__(self):
+        return self._func_ptr
+
+    def __call__(self, *args, **kwargs):
+        args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)]
+        if self.has_pyx_skip_dispatch():
+            output = self._pyfunc(*args[:-1], **kwargs)
+        else:
+            output = self._pyfunc(*args, **kwargs)
+        return self._signature.res_dtype(output)
+
+    def has_pyx_skip_dispatch(self):
+        if not self._signature.arg_names:
+            return False
+        if any(
+            name == "__pyx_skip_dispatch" for name in self._signature.arg_names[:-1]
+        ):
+            raise ValueError("skip_dispatch parameter must be last")
+        return self._signature.arg_names[-1] == "__pyx_skip_dispatch"
+
+    def numpy_arg_dtypes(self):
+        return self._signature.arg_dtypes
+
+    def numpy_output_dtype(self):
+        return self._signature.res_dtype
+
+
+def wrap_cython_function(func, restype, arg_types):
+    impls = _available_impls(func)
+    compatible = []
+    for sig, capsule in impls:
+        if sig.provides(restype, arg_types):
+            compatible.append((sig, capsule))
+
+    def sort_key(args):
+        sig, _ = args
+
+        # Prefer functions with fewer input bytes
+        argsize = sum(np.dtype(dtype).itemsize for dtype in sig.arg_dtypes)
+
+        # Prefer functions with more exact (integer) arguments
+        num_inexact = sum(np.issubdtype(dtype, np.inexact) for dtype in sig.arg_dtypes)
+        return (num_inexact, argsize)
+
+    compatible.sort(key=sort_key)
+
+    if not compatible:
+        raise NotImplementedError(f"Could not find a compatible impl of {func}")
+    sig, capsule = compatible[0]
+    return _CythonWrapper(func, sig, capsule)
diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index 23cc391810..569bd9b164 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -6,6 +6,7 @@
 import numba
 import numpy as np
+from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

 from pytensor import config
 from pytensor.graph.basic import Apply
@@ -27,6 +28,7 @@
     OR,
     XOR,
     Add,
+    Composite,
     IntDiv,
     Mean,
     Mul,
@@ -40,6 +42,7 @@
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
+from pytensor.tensor.type import scalar


 @singledispatch
@@ -216,6 +219,8 @@ def careduce_axis(x):

     """
+    axis = normalize_axis_index(axis, ndim)
+
     reduce_elemwise_fn_name = "careduce_axis"

     identity = str(identity)
@@ -338,6 +343,8 @@ def careduce_maximum(input):
     if len(axes) == 1:
         return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)

+    axes = normalize_axis_tuple(axes, ndim)
+
     careduce_fn_name = f"careduce_{scalar_op}"
     global_env = {}
     to_reduce = reversed(sorted(axes))
@@ -407,6 +414,8 @@ def jit_compile_reducer(node, fn, **kwds):


 def create_axis_apply_fn(fn, axis, ndim, dtype):
+    axis = normalize_axis_index(axis, ndim)
+
     reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)

     @numba_basic.numba_njit(boundscheck=False)
@@ -424,8 +433,17 @@ def axis_apply_fn(x):

 @numba_funcify.register(Elemwise)
 def numba_funcify_Elemwise(op, node, **kwargs):
-
-    scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
+    # Creating a new scalar node is more involved and unnecessary
+    # if the scalar_op is composite, as the fgraph already contains
+    # all the necessary information.
+    scalar_node = None
+    if not isinstance(op.scalar_op, Composite):
+        scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
+        scalar_node = op.scalar_op.make_node(*scalar_inputs)
+
+    scalar_op_fn = numba_funcify(
+        op.scalar_op, node=scalar_node, parent_node=node, inline="always", **kwargs
+    )

     elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
     elemwise_fn_name = elemwise_fn.__name__
@@ -599,6 +617,7 @@ def numba_funcify_Softmax(op, node, **kwargs):

     axis = op.axis
     if axis is not None:
+        axis = normalize_axis_index(axis, x_at.ndim)
         reduce_max_py = create_axis_reducer(
             scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
         )
@@ -636,6 +655,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):

     axis = op.axis
     if axis is not None:
+        axis = normalize_axis_index(axis, sm_at.ndim)
         reduce_sum_py = create_axis_reducer(
             add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
         )
@@ -667,6 +687,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):

     axis = op.axis
     if axis is not None:
+        axis = normalize_axis_index(axis, x_at.ndim)
         reduce_max_py = create_axis_reducer(
             scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
         )
diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py
index ad462cf58a..9871584454 100644
--- a/pytensor/link/numba/dispatch/extra_ops.py
+++ b/pytensor/link/numba/dispatch/extra_ops.py
@@ -7,6 +7,7 @@
 from pytensor import config
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
+from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor.extra_ops import (
     Bartlett,
     BroadcastTo,
@@ -36,31 +37,62 @@ def numba_funcify_CumOp(op, node, **kwargs):
     mode = op.mode
     ndim = node.outputs[0].ndim

+    if axis < 0:
+        axis = ndim + axis
+    if axis < 0 or axis >= ndim:
+        raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}")
+
     reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
+    reaxis_first_inv = tuple(np.argsort(reaxis_first))

     if mode == "add":
-        np_func = np.add
-        identity = 0
+
+        if ndim == 1:
+
+            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            def cumop(x):
+                return np.cumsum(x)
+
+        else:
+
+            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            def cumop(x):
+                out_dtype = x.dtype
+                if x.shape[axis] < 2:
+                    return x.astype(out_dtype)
+
+                x_axis_first = x.transpose(reaxis_first)
+                res = np.empty(x_axis_first.shape, dtype=out_dtype)
+
+                res[0] = x_axis_first[0]
+                for m in range(1, x.shape[axis]):
+                    res[m] = res[m - 1] + x_axis_first[m]
+
+                return res.transpose(reaxis_first_inv)
+
     else:
-        np_func = np.multiply
-        identity = 1
+        if ndim == 1:
+
+            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            def cumop(x):
+                return np.cumprod(x)

-    @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
-    def cumop(x):
-        out_dtype = x.dtype
-        if x.shape[axis] < 2:
-            return x.astype(out_dtype)
+        else:

-        x_axis_first = x.transpose(reaxis_first)
-        res = np.empty(x_axis_first.shape, dtype=out_dtype)
+            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            def cumop(x):
+                out_dtype = x.dtype
+                if x.shape[axis] < 2:
+                    return x.astype(out_dtype)

-        for m in numba.prange(x.shape[axis]):
-            if m == 0:
-                np_func(identity, x_axis_first[m], res[m])
-            else:
-                np_func(res[m - 1], x_axis_first[m], res[m])
+                x_axis_first = x.transpose(reaxis_first)
+                res = np.empty(x_axis_first.shape, dtype=out_dtype)

-        return res.transpose(reaxis_first)
+                res[0] = x_axis_first[0]
+                for m in range(1, x.shape[axis]):
+                    res[m] = res[m - 1] * x_axis_first[m]
+
+                return res.transpose(reaxis_first)

     return cumop
@@ -346,3 +378,18 @@ def broadcast_to(x, *shape):
         return np.broadcast_to(x, scalars_shape)

     return broadcast_to
+
+
+@numba_funcify.register(CheckAndRaise)
+def numba_funcify_CheckAndRaise(op, node, **kwargs):
+    error = op.exc_type
+    msg = op.msg
+
+    @numba_basic.numba_njit
+    def check_and_raise(x, *conditions):
+        for cond in conditions:
+            if not cond:
+                raise error(msg)
+        return x
+
+    return check_and_raise
diff --git a/pytensor/link/numba/dispatch/nlinalg.py b/pytensor/link/numba/dispatch/nlinalg.py
index aabc89b4d9..21fa34e1bb 100644
--- a/pytensor/link/numba/dispatch/nlinalg.py
+++ b/pytensor/link/numba/dispatch/nlinalg.py
@@ -25,31 +25,20 @@ def numba_funcify_SVD(op, node, **kwargs):
     full_matrices = op.full_matrices
     compute_uv = op.compute_uv
+    out_dtype = np.dtype(node.outputs[0].dtype)

-    if not compute_uv:
-
-        warnings.warn(
-            (
-                "Numba will use object mode to allow the "
-                "`compute_uv` argument to `numpy.linalg.svd`."
-            ),
-            UserWarning,
-        )
+    inputs_cast = int_to_float_fn(node.inputs, out_dtype)

-        ret_sig = get_numba_type(node.outputs[0].type)
+    if not compute_uv:

-        @numba_basic.numba_njit
+        @numba_basic.numba_njit()
         def svd(x):
-            with numba.objmode(ret=ret_sig):
-                ret = np.linalg.svd(x, full_matrices, compute_uv)
+            _, ret, _ = np.linalg.svd(inputs_cast(x), full_matrices)
             return ret

     else:
-        out_dtype = node.outputs[0].type.numpy_dtype
-        inputs_cast = int_to_float_fn(node.inputs, out_dtype)
-
-        @numba_basic.numba_njit(inline="always")
+        @numba_basic.numba_njit()
         def svd(x):
             return np.linalg.svd(inputs_cast(x), full_matrices)
diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py
index 5b905a99e9..d6c68d3208 100644
--- a/pytensor/link/numba/dispatch/scalar.py
+++ b/pytensor/link/numba/dispatch/scalar.py
@@ -1,16 +1,18 @@
 import math
-from functools import reduce
 from typing import List

 import numpy as np
-import scipy
-import scipy.special

 from pytensor import config
 from pytensor.compile.ops import ViewOp
 from pytensor.graph.basic import Variable
 from pytensor.link.numba.dispatch import basic as numba_basic
-from pytensor.link.numba.dispatch.basic import create_numba_signature, numba_funcify
+from pytensor.link.numba.dispatch.basic import (
+    create_numba_signature,
+    generate_fallback_impl,
+    numba_funcify,
+)
+from pytensor.link.numba.dispatch.cython_support import wrap_cython_function
 from pytensor.link.utils import (
     compile_function_src,
     get_name_for_object,
@@ -36,69 +38,83 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
     # TODO: Do we need to cache these functions so that we don't end up
     # compiling the same Numba function over and over again?

-    scalar_func_name = op.nfunc_spec[0]
+    scalar_func_path = op.nfunc_spec[0]
+    scalar_func_numba = None

-    if scalar_func_name.startswith("scipy."):
-        func_package = scipy
-        scalar_func_name = scalar_func_name.split(".", 1)[-1]
-    else:
-        func_package = np
+    *module_path, scalar_func_name = scalar_func_path.split(".")
+    if not module_path:
+        # Assume it is numpy, and numba has an implementation
+        scalar_func_numba = getattr(np, scalar_func_name)

-    if "." in scalar_func_name:
-        scalar_func = reduce(getattr, [scipy] + scalar_func_name.split("."))
-    else:
-        scalar_func = getattr(func_package, scalar_func_name)
+    input_dtypes = [np.dtype(input.type.dtype) for input in node.inputs]
+    output_dtypes = [np.dtype(output.type.dtype) for output in node.outputs]
+
+    if len(output_dtypes) != 1:
+        raise ValueError("ScalarOps with more than one output are not supported")
+
+    output_dtype = output_dtypes[0]
+
+    input_inner_dtypes = None
+    output_inner_dtype = None
+
+    # Cython functions might have an additional argument
+    has_pyx_skip_dispatch = False
+
+    if scalar_func_path.startswith("scipy.special"):
+        import scipy.special.cython_special
+
+        cython_func = getattr(scipy.special.cython_special, scalar_func_name, None)
+        if cython_func is not None:
+            # try:
+            scalar_func_numba = wrap_cython_function(
+                cython_func, output_dtype, input_dtypes
+            )
+            has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch
+            input_inner_dtypes = scalar_func_numba.numpy_arg_dtypes()
+            output_inner_dtype = scalar_func_numba.numpy_output_dtype()
+            # except NotImplementedError:
+            #     pass

-    scalar_op_fn_name = get_name_for_object(scalar_func)
+    if scalar_func_numba is None:
+        scalar_func_numba = generate_fallback_impl(op, node, **kwargs)
+
+    scalar_op_fn_name = get_name_for_object(scalar_func_numba)
     unique_names = unique_name_generator(
-        [scalar_op_fn_name, "scalar_func"], suffix_sep="_"
+        [scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
     )

-    global_env = {"scalar_func": scalar_func}
+    global_env = {"scalar_func_numba": scalar_func_numba}

-    input_tmp_dtypes = None
-    if func_package == scipy and hasattr(scalar_func, "types"):
-        # The `numba-scipy` bindings don't provide implementations for all
-        # inputs types, so we need to convert the inputs to floats and back.
-        inp_dtype_kinds = tuple(np.dtype(inp.type.dtype).kind for inp in node.inputs)
-        accepted_inp_kinds = tuple(
-            sig_type.split("->")[0] for sig_type in scalar_func.types
-        )
-        if not any(
-            all(dk == ik for dk, ik in zip(inp_dtype_kinds, ok_kinds))
-            for ok_kinds in accepted_inp_kinds
-        ):
-            # They're usually ordered from lower-to-higher precision, so
-            # we pick the last acceptable input types
-            #
-            # XXX: We should pick the first acceptable float/int types in
-            # reverse, excluding all the incompatible ones (e.g. `"0"`).
-            # The assumption is that this is only used by `numba-scipy`-exposed
-            # functions, although it's possible for this to be triggered by
-            # something else from the `scipy` package
-            input_tmp_dtypes = tuple(np.dtype(k) for k in accepted_inp_kinds[-1])
-
-    if input_tmp_dtypes is None:
+    if input_inner_dtypes is None and output_inner_dtype is None:
         unique_names = unique_name_generator(
-            [scalar_op_fn_name, "scalar_func"], suffix_sep="_"
+            [scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
         )
         input_names = ", ".join(
             [unique_names(v, force_unique=True) for v in node.inputs]
         )
-        scalar_op_src = f"""
+        if not has_pyx_skip_dispatch:
+            scalar_op_src = f"""
 def {scalar_op_fn_name}({input_names}):
-    return scalar_func({input_names})
-    """
+    return scalar_func_numba({input_names})
+    """
+        else:
+            scalar_op_src = f"""
+def {scalar_op_fn_name}({input_names}):
+    return scalar_func_numba({input_names}, np.intc(1))
+    """
+
     else:
         global_env["direct_cast"] = numba_basic.direct_cast
-        global_env["output_dtype"] = np.dtype(node.outputs[0].type.dtype)
+        global_env["output_dtype"] = np.dtype(output_inner_dtype)
         input_tmp_dtype_names = {
-            f"inp_tmp_dtype_{i}": i_dtype for i, i_dtype in enumerate(input_tmp_dtypes)
+            f"inp_tmp_dtype_{i}": i_dtype
+            for i, i_dtype in enumerate(input_inner_dtypes)
         }
         global_env.update(input_tmp_dtype_names)

         unique_names = unique_name_generator(
-            [scalar_op_fn_name, "scalar_func"] + list(global_env.keys()), suffix_sep="_"
+            [scalar_op_fn_name, "scalar_func_numba"] + list(global_env.keys()),
+            suffix_sep="_",
         )

         input_names = [unique_names(v, force_unique=True) for v in node.inputs]
@@ -110,10 +126,16 @@ def {scalar_op_fn_name}({input_names}):
                 )
             ]
         )
-        scalar_op_src = f"""
+        if not has_pyx_skip_dispatch:
+            scalar_op_src = f"""
 def {scalar_op_fn_name}({', '.join(input_names)}):
-    return direct_cast(scalar_func({converted_call_args}), output_dtype)
-    """
+    return direct_cast(scalar_func_numba({converted_call_args}), output_dtype)
+    """
+        else:
+            scalar_op_src = f"""
+def {scalar_op_fn_name}({', '.join(input_names)}):
+    return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
+    """

     scalar_op_fn = compile_function_src(
         scalar_op_src, scalar_op_fn_name, {**globals(), **global_env}
@@ -122,7 +144,10 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
     signature = create_numba_signature(node, force_scalar=True)

     return numba_basic.numba_njit(
-        signature, inline="always", fastmath=config.numba__fastmath
+        signature,
+        inline="always",
+        fastmath=config.numba__fastmath,
+        cache=False,
     )(scalar_op_fn)
@@ -220,7 +245,7 @@ def clip(_x, _min, _max):

 @numba_funcify.register(Composite)
 def numba_funcify_Composite(op, node, **kwargs):
-    signature = create_numba_signature(node, force_scalar=True)
+    signature = create_numba_signature(op.fgraph, force_scalar=True)

     _ = kwargs.pop("storage_map", None)
diff --git a/pytensor/sparse/sandbox/sp.py b/pytensor/sparse/sandbox/sp.py
index 9b43e74c39..33b7de4665 100644
--- a/pytensor/sparse/sandbox/sp.py
+++ b/pytensor/sparse/sandbox/sp.py
@@ -154,17 +154,17 @@ def evaluate(inshp, kshp, strides=(1, 1), nkern=1, mode="valid", ws=True):

             # FOR EACH OUTPUT PIXEL...
             # loop over output image height
-            for oy in np.arange(lbound[0], ubound[0], dy):
+            for oy in np.arange(lbound[0], ubound[0], dy, dtype=int):
                 # loop over output image width
-                for ox in np.arange(lbound[1], ubound[1], dx):
+                for ox in np.arange(lbound[1], ubound[1], dx, dtype=int):

                     # kern[l] is filter value to apply at (oj,oi)
                     # for (iy,ix)
                     l = 0  # noqa: E741

                     # ... ITERATE OVER INPUT UNITS IN RECEPTIVE FIELD
-                    for ky in oy + np.arange(kshp[0]):
-                        for kx in ox + np.arange(kshp[1]):
+                    for ky in oy + np.arange(kshp[0], dtype=int):
+                        for kx in ox + np.arange(kshp[1], dtype=int):

                             # verify if we are still within image
                             # boundaries. Equivalent to
@@ -176,13 +176,15 @@ def evaluate(inshp, kshp, strides=(1, 1), nkern=1, mode="valid", ws=True):
                                 # convert to "valid" input space
                                 # coords used to determine column
                                 # index to write to in sparse mat
-                                iy, ix = np.array((ky, kx)) - topleft
+                                iy, ix = np.array((ky, kx), dtype=int) - topleft

                                 # determine raster-index of input pixel...
                                 # taking into account multiple
                                 # input features
                                 col = int(
-                                    iy * inshp[2] + ix + fmapi * np.prod(inshp[1:])
+                                    iy * inshp[2]
+                                    + ix
+                                    + fmapi * np.prod(inshp[1:], dtype=int)
                                 )

                                 # convert oy,ox values to output
@@ -192,7 +194,7 @@ def evaluate(inshp, kshp, strides=(1, 1), nkern=1, mode="valid", ws=True):
                                 else:
                                     (y, x) = (oy, ox) - topleft
                                 # taking into account step size
-                                (y, x) = np.array([y, x]) / (dy, dx)
+                                (y, x) = np.array([y, x], dtype=int) / (dy, dx)

                                 # convert to row index of sparse matrix
                                 if ws:
@@ -212,7 +214,7 @@ def evaluate(inshp, kshp, strides=(1, 1), nkern=1, mode="valid", ws=True):
                                 # onto the sparse columns (idea of
                                 # kernel map)
                                 # n*... only for sparse
-                                spmat[row + n * outsize, col] = tapi + 1
+                                spmat[int(row + n * outsize), int(col)] = tapi + 1

                                 # total number of active taps
                                 # (used for kmap)
diff --git a/pytensor/tensor/nnet/corr.py b/pytensor/tensor/nnet/corr.py
index c747fb9f7c..6051a03b9b 100644
--- a/pytensor/tensor/nnet/corr.py
+++ b/pytensor/tensor/nnet/corr.py
@@ -692,7 +692,7 @@ def make_node(self, img, kern):
         if kern.type.ndim != 4:
             raise TypeError("kern must be 4D tensor")

-        out_shape = tuple(
+        out_shape = (
             1 if img.type.shape[0] == 1 else None,
             1 if kern.type.shape[0] == 1 else None,
             None,
diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py
index 94cd33fe5c..4768ba77c2 100644
--- a/tests/link/numba/test_basic.py
+++ b/tests/link/numba/test_basic.py
@@ -681,9 +681,7 @@ def test_perform_params():
         out = [out]

     out_fg = FunctionGraph([x], out)
-
-    with pytest.warns(UserWarning, match=".*object mode.*"):
-        compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])
+    compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])


 def test_perform_type_convert():
@@ -702,9 +700,7 @@ def test_perform_type_convert():
         out = [out]

     out_fg = FunctionGraph([x], out)
-
-    with pytest.warns(UserWarning, match=".*object mode.*"):
-        compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])
+    compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])


 @pytest.mark.parametrize(
diff --git a/tests/link/numba/test_cython_support.py b/tests/link/numba/test_cython_support.py
new file mode 100644
index 0000000000..b96a22098f
--- /dev/null
+++ b/tests/link/numba/test_cython_support.py
@@ -0,0 +1,92 @@
+import numpy as np
+import pytest
+import scipy.special.cython_special
+from numba.types import float32, float64, int32, int64
+
+from pytensor.link.numba.dispatch.cython_support import Signature, wrap_cython_function
+
+
+@pytest.mark.parametrize(
+    "sig, expected_result, expected_args",
+    [
+        (b"double(double)", np.float64, [np.float64]),
+        (b"float(unsigned int)", np.float32, [np.uintc]),
+        (b"unsigned char(unsigned short foo)", np.ubyte, [np.ushort]),
+        (
+            b"unsigned char(unsigned short foo, double bar)",
+            np.ubyte,
+            [np.ushort, np.float64],
+        ),
+    ],
+)
+def test_parse_signature(sig, expected_result, expected_args):
+    actual = Signature.from_c_types(sig)
+    assert actual.res_dtype == expected_result
+    assert actual.arg_dtypes == expected_args
+
+
+@pytest.mark.parametrize(
+    "have, want, should_provide",
+    [
+        (b"double(int)", b"float(int)", True),
+        (b"float(int)", b"double(int)", False),
+        (b"double(unsigned short)", b"double(unsigned char)", True),
+        (b"double(unsigned char)", b"double(short)", False),
+        (b"short(double)", b"int(double)", True),
+        (b"int(double)", b"short(double)", False),
+        (b"float(double, int)", b"float(double, short)", True),
+    ],
+)
+def test_signature_provides(have, want, should_provide):
+    have = Signature.from_c_types(have)
+    want = Signature.from_c_types(want)
+    provides = have.provides(want.res_dtype, want.arg_dtypes)
+    assert provides == should_provide
+
+
+@pytest.mark.parametrize(
+    "func, output, inputs, expected",
+    [
+        (
+            scipy.special.cython_special.agm,
+            np.float64,
+            [np.float64, np.float64],
+            float64(float64, float64, int32),
+        ),
+        (
+            scipy.special.cython_special.erfc,
+            np.float64,
+            [np.float64],
+            float64(float64, int32),
+        ),
+        (
+            scipy.special.cython_special.expit,
+            np.float32,
+            [np.float32],
+            float32(float32, int32),
+        ),
+        (
+            scipy.special.cython_special.expit,
+            np.float64,
+            [np.float64],
+            float64(float64, int32),
+        ),
+        (
+            # expn doesn't have a float32 implementation
+            scipy.special.cython_special.expn,
+            np.float32,
+            [np.float32, np.float32],
+            float64(float64, float64, int32),
+        ),
+        (
+            # We choose the integer implementation if possible
+            scipy.special.cython_special.expn,
+            np.float32,
+            [np.int64, np.float32],
+            float64(int64, float64, int32),
+        ),
+    ],
+)
+def test_choose_signature(func, output, inputs, expected):
+    wrapper = wrap_cython_function(func, output, inputs)
+    assert wrapper.signature() == expected
diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py
index 7f2fd0a67e..db0c6dadab 100644
--- a/tests/link/numba/test_elemwise.py
+++ b/tests/link/numba/test_elemwise.py
@@ -57,6 +57,12 @@
             lambda x: at.erfc(x),
             None,
         ),
+        (
+            [at.vector()],
+            [rng.standard_normal(100).astype(config.floatX)],
+            lambda x: at.erfcx(x),
+            None,
+        ),
         (
             [at.vector() for i in range(4)],
             [rng.standard_normal(100).astype(config.floatX) for i in range(4)],
diff --git a/tests/link/numba/test_extra_ops.py b/tests/link/numba/test_extra_ops.py
index 4d50b5a9e0..8cf1fdc6bd 100644
--- a/tests/link/numba/test_extra_ops.py
+++ b/tests/link/numba/test_extra_ops.py
@@ -80,6 +80,13 @@ def test_BroadcastTo(x, shape):
             1,
             "add",
         ),
+        (
+            set_test_value(
+                at.dtensor3(), np.arange(30, dtype=config.floatX).reshape((2, 3, 5))
+            ),
+            -1,
+            "add",
+        ),
         (
             set_test_value(
                 at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py
index 9261388332..3018b9a97a 100644
--- a/tests/link/numba/test_nlinalg.py
+++ b/tests/link/numba/test_nlinalg.py
@@ -477,7 +477,7 @@ def test_QRFull(x, mode, exc):
             ),
             True,
             False,
-            UserWarning,
+            None,
         ),
     ],
 )