From ded5cbbce869acfac84428990041344e1b7ddb78 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Wed, 20 Jul 2022 15:37:30 -0500
Subject: [PATCH 01/17] Use objmode in scipy.special without numba-scipy

---
 pytensor/configdefaults.py               |  6 +++++
 pytensor/link/numba/dispatch/basic.py    | 11 +++++++---
 pytensor/link/numba/dispatch/elemwise.py | 15 +++++++++++--
 pytensor/link/numba/dispatch/scalar.py   | 28 +++++++++++++++++++++---
 tests/link/numba/test_elemwise.py        |  6 +++++
 5 files changed, 58 insertions(+), 8 deletions(-)

diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py
index d6afff0fd5..0a6eae4c97 100644
--- a/pytensor/configdefaults.py
+++ b/pytensor/configdefaults.py
@@ -1252,6 +1252,12 @@ def add_numba_configvars():
         BoolParam(True),
         in_c_key=False,
     )
+    config.add(
+        "numba_scipy",
+        ("Enable usage of the numba_scipy package for special functions",),
+        BoolParam(True),
+        in_c_key=False,
+    )
 
 
 def _default_compiledirname():
diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py
index c081fbe9ef..c08e95627d 100644
--- a/pytensor/link/numba/dispatch/basic.py
+++ b/pytensor/link/numba/dispatch/basic.py
@@ -319,9 +319,8 @@ def numba_typify(data, dtype=None, **kwargs):
     return data
 
 
-@singledispatch
-def numba_funcify(op, node=None, storage_map=None, **kwargs):
-    """Create a Numba compatible function from an PyTensor `Op`."""
+def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
+    """Create a Numba compatible function from a PyTensor `Op`."""
 
     warnings.warn(
         f"Numba will use object mode to run {op}'s perform method",
@@ -375,6 +374,12 @@ def perform(*inputs):
     return perform
 
 
+@singledispatch
+def numba_funcify(op, node=None, storage_map=None, **kwargs):
+    """Generate a numba function for a given op and apply node."""
+    return generate_fallback_impl(op, node, storage_map, **kwargs)
+
+
 @numba_funcify.register(OpFromGraph)
 def numba_funcify_OpFromGraph(op, node=None, **kwargs):
 
diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index 23cc391810..474bb67a1f 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -27,6 +27,7 @@
     OR,
     XOR,
     Add,
+    Composite,
     IntDiv,
     Mean,
     Mul,
@@ -40,6 +41,7 @@
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
+from pytensor.tensor.type import scalar
 
 
 @singledispatch
@@ -424,8 +426,17 @@ def axis_apply_fn(x):
 
 @numba_funcify.register(Elemwise)
 def numba_funcify_Elemwise(op, node, **kwargs):
-
-    scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
+    # Creating a new scalar node is more involved and unnecessary
+    # if the scalar_op is composite, as the fgraph already contains
+    # all the necessary information.
+    scalar_node = None
+    if not isinstance(op.scalar_op, Composite):
+        scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
+        scalar_node = op.scalar_op.make_node(*scalar_inputs)
+
+    scalar_op_fn = numba_funcify(
+        op.scalar_op, node=scalar_node, parent_node=node, inline="always", **kwargs
+    )
 
     elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
     elemwise_fn_name = elemwise_fn.__name__
diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py
index 5b905a99e9..ce7dcbd830 100644
--- a/pytensor/link/numba/dispatch/scalar.py
+++ b/pytensor/link/numba/dispatch/scalar.py
@@ -1,4 +1,5 @@
 import math
+import warnings
 from functools import reduce
 from typing import List
 
@@ -10,7 +11,11 @@
 from pytensor.compile.ops import ViewOp
 from pytensor.graph.basic import Variable
 from pytensor.link.numba.dispatch import basic as numba_basic
-from pytensor.link.numba.dispatch.basic import create_numba_signature, numba_funcify
+from pytensor.link.numba.dispatch.basic import (
+    create_numba_signature,
+    generate_fallback_impl,
+    numba_funcify,
+)
 from pytensor.link.utils import (
     compile_function_src,
     get_name_for_object,
@@ -37,14 +42,31 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
     # compiling the same Numba function over and over again?
 
     scalar_func_name = op.nfunc_spec[0]
+    scalar_func = None
 
     if scalar_func_name.startswith("scipy."):
         func_package = scipy
         scalar_func_name = scalar_func_name.split(".", 1)[-1]
+
+        use_numba_scipy = config.numba_scipy
+        if use_numba_scipy:
+            try:
+                import numba_scipy  # noqa: F401
+            except ImportError:
+                use_numba_scipy = False
+        if not use_numba_scipy:
+            warnings.warn(
+                "Native numba versions of scipy functions might be "
+                "available if numba-scipy is installed.",
+                UserWarning,
+            )
+            scalar_func = generate_fallback_impl(op, node, **kwargs)
     else:
         func_package = np
 
-    if "." in scalar_func_name:
+    if scalar_func is not None:
+        pass
+    elif "." 
in scalar_func_name: scalar_func = reduce(getattr, [scipy] + scalar_func_name.split(".")) else: scalar_func = getattr(func_package, scalar_func_name) @@ -220,7 +242,7 @@ def clip(_x, _min, _max): @numba_funcify.register(Composite) def numba_funcify_Composite(op, node, **kwargs): - signature = create_numba_signature(node, force_scalar=True) + signature = create_numba_signature(op.fgraph, force_scalar=True) _ = kwargs.pop("storage_map", None) diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 7f2fd0a67e..db0c6dadab 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -57,6 +57,12 @@ lambda x: at.erfc(x), None, ), + ( + [at.vector()], + [rng.standard_normal(100).astype(config.floatX)], + lambda x: at.erfcx(x), + None, + ), ( [at.vector() for i in range(4)], [rng.standard_normal(100).astype(config.floatX) for i in range(4)], From a156c10364d3ab76b77884934ab7da6ba4ea432e Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 27 Jul 2022 13:12:44 -0500 Subject: [PATCH 02/17] Replace numba_scipy --- pytensor/configdefaults.py | 6 - .../link/numba/dispatch/cython_support.py | 211 ++++++++++++++++++ pytensor/link/numba/dispatch/scalar.py | 134 +++++------ tests/link/numba/test_cython_support.py | 92 ++++++++ 4 files changed, 370 insertions(+), 73 deletions(-) create mode 100644 pytensor/link/numba/dispatch/cython_support.py create mode 100644 tests/link/numba/test_cython_support.py diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index 0a6eae4c97..d6afff0fd5 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -1252,12 +1252,6 @@ def add_numba_configvars(): BoolParam(True), in_c_key=False, ) - config.add( - "numba_scipy", - ("Enable usage of the numba_scipy package for special functions",), - BoolParam(True), - in_c_key=False, - ) def _default_compiledirname(): diff --git a/pytensor/link/numba/dispatch/cython_support.py b/pytensor/link/numba/dispatch/cython_support.py new file mode 100644 index 0000000000..ab551bb1d3 --- /dev/null +++ b/pytensor/link/numba/dispatch/cython_support.py @@ -0,0 +1,211 @@ +import ctypes +import importlib +import re +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, cast + +import numba +import numpy as np +from numpy.typing import DTypeLike +from scipy import LowLevelCallable + + +_C_TO_NUMPY: Dict[str, DTypeLike] = { + "bool": np.bool_, + "signed char": np.byte, + "unsigned char": np.ubyte, + "short": np.short, + "unsigned short": np.ushort, + "int": np.intc, + "unsigned int": np.uintc, + "long": np.int_, + "unsigned long": np.uint, + "long long": np.longlong, + "float": np.single, + "double": np.double, + "long double": np.longdouble, + "float complex": np.csingle, + "double complex": np.cdouble, +} + + +@dataclass +class Signature: + res_dtype: DTypeLike + res_c_type: str + arg_dtypes: List[DTypeLike] + arg_c_types: List[str] + arg_names: List[Optional[str]] + + @property + def arg_numba_types(self) -> List[DTypeLike]: + return [numba.from_dtype(dtype) for dtype in self.arg_dtypes] + + def can_cast_args(self, args: List[DTypeLike]) -> bool: + ok = True + count = 0 + for name, dtype in zip(self.arg_names, self.arg_dtypes): + if name == "__pyx_skip_dispatch": + continue + if len(args) <= count: + raise ValueError("Incorrect number of arguments") + ok &= np.can_cast(args[count], dtype) + count += 1 + if count != len(args): + return False + return ok + + def provides(self, restype: DTypeLike, 
arg_dtypes: List[DTypeLike]) -> bool:
+        args_ok = self.can_cast_args(arg_dtypes)
+        if np.issubdtype(restype, np.inexact):
+            result_ok = np.can_cast(self.res_dtype, restype, casting="same_kind")
+            # We do not want to provide less accuracy than advertised
+            result_ok &= np.dtype(self.res_dtype).itemsize >= np.dtype(restype).itemsize
+        else:
+            result_ok = np.can_cast(self.res_dtype, restype)
+        return args_ok and result_ok
+
+    @staticmethod
+    def from_c_types(signature: bytes) -> "Signature":
+        # Match strings like "double(int, double)"
+        # and extract the return type and the joined arguments
+        expr = re.compile(rb"\s*(?P<restype>[\w ]*\w+)\s*\((?P<args>[\w\s,]*)\)")
+        re_match = re.fullmatch(expr, signature)
+
+        if re_match is None:
+            raise ValueError(f"Invalid signature: {signature.decode()}")
+
+        groups = re_match.groupdict()
+        res_c_type = groups["restype"].decode()
+        res_dtype: DTypeLike = _C_TO_NUMPY[res_c_type]
+
+        raw_args = groups["args"]
+
+        decl_expr = re.compile(
+            rb"\s*(?P<type>((long )|(unsigned )|(signed )|(double )|)"
+            rb"((double)|(float)|(int)|(short)|(char)|(long)|(bool)|(complex)))"
+            rb"(\s(?P<name>[\w_]*))?\s*"
+        )
+
+        arg_dtypes = []
+        arg_names: List[Optional[str]] = []
+        arg_c_types = []
+        for raw_arg in raw_args.split(b","):
+            re_match = re.fullmatch(decl_expr, raw_arg)
+            if re_match is None:
+                raise ValueError(f"Invalid signature: {signature.decode()}")
+            groups = re_match.groupdict()
+            arg_c_type = groups["type"].decode()
+            try:
+                arg_dtype = _C_TO_NUMPY[arg_c_type]
+            except KeyError:
+                raise ValueError(f"Unknown C type: {arg_c_type}")
+
+            arg_c_types.append(arg_c_type)
+            arg_dtypes.append(arg_dtype)
+            name = groups["name"]
+            if not name:
+                arg_names.append(None)
+            else:
+                arg_names.append(name.decode())
+
+        return Signature(res_dtype, res_c_type, arg_dtypes, arg_c_types, arg_names)
+
+
+def _available_impls(func: Callable) -> List[Tuple[Signature, Any]]:
+    """Find all available implementations for a fused cython function."""
+    impls = []
+    mod = importlib.import_module(func.__module__)
+
+    signatures = getattr(func, "__signatures__", None)
+    if signatures is not None:
+        # Cython function with __signatures__ should be fused and thus
+        # indexable
+        func_map = cast(Mapping, func)
+        candidates = [func_map[key] for key in signatures]
+    else:
+        candidates = [func]
+    for candidate in candidates:
+        name = candidate.__name__
+        capsule = mod.__pyx_capi__[name]
+        llc = LowLevelCallable(capsule)
+        try:
+            signature = Signature.from_c_types(llc.signature.encode())
+        except KeyError:
+            continue
+        impls.append((signature, capsule))
+    return impls
+
+
+class _CythonWrapper(numba.types.WrapperAddressProtocol):
+    def __init__(self, pyfunc, signature, capsule):
+        self._keep_alive = capsule
+        get_name = ctypes.pythonapi.PyCapsule_GetName
+        get_name.restype = ctypes.c_char_p
+        get_name.argtypes = (ctypes.py_object,)
+
+        raw_signature = get_name(capsule)
+
+        get_pointer = ctypes.pythonapi.PyCapsule_GetPointer
+        get_pointer.restype = ctypes.c_void_p
+        get_pointer.argtypes = (ctypes.py_object, ctypes.c_char_p)
+        self._func_ptr = get_pointer(capsule, raw_signature)
+
+        self._signature = signature
+        self._pyfunc = pyfunc
+
+    def signature(self):
+        return numba.from_dtype(self._signature.res_dtype)(
+            *self._signature.arg_numba_types
+        )
+
+    def __wrapper_address__(self):
+        return self._func_ptr
+
+    def __call__(self, *args, **kwargs):
+        args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)]
+        if self.has_pyx_skip_dispatch():
+            output = self._pyfunc(*args[:-1], **kwargs)
+        else:
+            output = self._pyfunc(*args, **kwargs)
+        return self._signature.res_dtype(output)
+
+    def has_pyx_skip_dispatch(self):
+        if not self._signature.arg_names:
+            return False
+        if any(
+            name == "__pyx_skip_dispatch" for name in self._signature.arg_names[:-1]
+        ):
+            raise ValueError("skip_dispatch parameter must be last")
+        return self._signature.arg_names[-1] == "__pyx_skip_dispatch"
+
+    def numpy_arg_dtypes(self):
+        return self._signature.arg_dtypes
+
+    def numpy_output_dtype(self):
+        return self._signature.res_dtype
+
+
+def wrap_cython_function(func, restype, arg_types):
+    impls = _available_impls(func)
+    compatible = []
+    for sig, capsule in impls:
+        if sig.provides(restype, arg_types):
+            compatible.append((sig, capsule))
+
+    def sort_key(args):
+        sig, _ = args
+
+        # Prefer functions with fewer input bytes
+        argsize = sum(np.dtype(dtype).itemsize for dtype in sig.arg_dtypes)
+
+        # Prefer functions with more exact (integer) arguments
+        num_inexact = sum(np.issubdtype(dtype, np.inexact) for dtype in sig.arg_dtypes)
+        return (num_inexact, argsize)
+
+    compatible.sort(key=sort_key)
+
+    if not compatible:
+        raise NotImplementedError(f"Could not find a compatible impl of {func}")
+    sig, capsule = compatible[0]
+    return _CythonWrapper(func, sig, capsule)
diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py
index ce7dcbd830..36793393f7 100644
--- a/pytensor/link/numba/dispatch/scalar.py
+++ b/pytensor/link/numba/dispatch/scalar.py
@@ -1,11 +1,7 @@
 import math
-import warnings
-from functools import reduce
 from typing import List
 
 import numpy as np
-import scipy
-import scipy.special
 
 from pytensor import config
 from pytensor.compile.ops import ViewOp
@@ -16,6 +12,7 @@
     generate_fallback_impl,
     numba_funcify,
 )
+from pytensor.link.numba.dispatch.cython_support import wrap_cython_function
 from pytensor.link.utils import (
     compile_function_src,
     get_name_for_object,
@@ -41,86 +38,83 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
     # TODO: Do we need to cache these functions so that we don't end up
     # compiling the same Numba function over and over again?
 
-    scalar_func_name = op.nfunc_spec[0]
-    scalar_func = None
-
-    if scalar_func_name.startswith("scipy."):
-        func_package = scipy
-        scalar_func_name = scalar_func_name.split(".", 1)[-1]
-
-        use_numba_scipy = config.numba_scipy
-        if use_numba_scipy:
-            try:
-                import numba_scipy  # noqa: F401
-            except ImportError:
-                use_numba_scipy = False
-        if not use_numba_scipy:
-            warnings.warn(
-                "Native numba versions of scipy functions might be "
-                "available if numba-scipy is installed.",
-                UserWarning,
+    scalar_func_path = op.nfunc_spec[0]
+    scalar_func_numba = None
+
+    *module_path, scalar_func_name = scalar_func_path.split(".")
+    if not module_path:
+        # Assume it is numpy, and numba has an implementation
+        scalar_func_numba = getattr(np, scalar_func_name)
+
+    input_dtypes = [np.dtype(input.type.dtype) for input in node.inputs]
+    output_dtypes = [np.dtype(output.type.dtype) for output in node.outputs]
+
+    if len(output_dtypes) != 1:
+        raise ValueError("ScalarOps with more than one output are not supported")
+
+    output_dtype = output_dtypes[0]
+
+    input_inner_dtypes = None
+    output_inner_dtype = None
+
+    # Cython functions might have an additional argument
+    has_pyx_skip_dispatch = False
+
+    if scalar_func_path.startswith("scipy.special"):
+        import scipy.special.cython_special
+
+        cython_func = getattr(scipy.special.cython_special, scalar_func_name, None)
+        if cython_func is not None:
+            # try:
+            scalar_func_numba = wrap_cython_function(
+                cython_func, output_dtype, input_dtypes
             )
-            scalar_func = generate_fallback_impl(op, node, **kwargs)
-    else:
-        func_package = np
+            has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch()
+            input_inner_dtypes = scalar_func_numba.numpy_arg_dtypes()
+            output_inner_dtype = scalar_func_numba.numpy_output_dtype()
+            # except NotImplementedError:
+            #     pass
 
-    if scalar_func is not None:
-        pass
-    elif "." in scalar_func_name:
-        scalar_func = reduce(getattr, [scipy] + scalar_func_name.split("."))
-    else:
-        scalar_func = getattr(func_package, scalar_func_name)
+    if scalar_func_numba is None:
+        scalar_func_numba = generate_fallback_impl(op, node, **kwargs)
 
-    scalar_op_fn_name = get_name_for_object(scalar_func)
+    scalar_op_fn_name = get_name_for_object(scalar_func_numba)
     unique_names = unique_name_generator(
-        [scalar_op_fn_name, "scalar_func"], suffix_sep="_"
+        [scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
    )
 
-    global_env = {"scalar_func": scalar_func}
+    global_env = {"scalar_func_numba": scalar_func_numba}
 
-    input_tmp_dtypes = None
-    if func_package == scipy and hasattr(scalar_func, "types"):
-        # The `numba-scipy` bindings don't provide implementations for all
-        # inputs types, so we need to convert the inputs to floats and back.
-        inp_dtype_kinds = tuple(np.dtype(inp.type.dtype).kind for inp in node.inputs)
-        accepted_inp_kinds = tuple(
-            sig_type.split("->")[0] for sig_type in scalar_func.types
-        )
-        if not any(
-            all(dk == ik for dk, ik in zip(inp_dtype_kinds, ok_kinds))
-            for ok_kinds in accepted_inp_kinds
-        ):
-            # They're usually ordered from lower-to-higher precision, so
-            # we pick the last acceptable input types
-            #
-            # XXX: We should pick the first acceptable float/int types in
-            # reverse, excluding all the incompatible ones (e.g. `"0"`).
- # The assumption is that this is only used by `numba-scipy`-exposed - # functions, although it's possible for this to be triggered by - # something else from the `scipy` package - input_tmp_dtypes = tuple(np.dtype(k) for k in accepted_inp_kinds[-1]) - - if input_tmp_dtypes is None: + if input_inner_dtypes is None and output_inner_dtype is None: unique_names = unique_name_generator( - [scalar_op_fn_name, "scalar_func"], suffix_sep="_" + [scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_" ) input_names = ", ".join( [unique_names(v, force_unique=True) for v in node.inputs] ) - scalar_op_src = f""" + if not has_pyx_skip_dispatch: + scalar_op_src = f""" +def {scalar_op_fn_name}({input_names}): + return scalar_func_numba({input_names}) + """ + else: + scalar_op_src = f""" def {scalar_op_fn_name}({input_names}): - return scalar_func({input_names}) - """ + return scalar_func_numba({input_names}, np.intc(1)) + """ + else: global_env["direct_cast"] = numba_basic.direct_cast - global_env["output_dtype"] = np.dtype(node.outputs[0].type.dtype) + global_env["output_dtype"] = np.dtype(output_inner_dtype) input_tmp_dtype_names = { - f"inp_tmp_dtype_{i}": i_dtype for i, i_dtype in enumerate(input_tmp_dtypes) + f"inp_tmp_dtype_{i}": i_dtype + for i, i_dtype in enumerate(input_inner_dtypes) } global_env.update(input_tmp_dtype_names) unique_names = unique_name_generator( - [scalar_op_fn_name, "scalar_func"] + list(global_env.keys()), suffix_sep="_" + [scalar_op_fn_name, "scalar_func_numba"] + list(global_env.keys()), + suffix_sep="_", ) input_names = [unique_names(v, force_unique=True) for v in node.inputs] @@ -132,10 +126,16 @@ def {scalar_op_fn_name}({input_names}): ) ] ) - scalar_op_src = f""" + if not has_pyx_skip_dispatch: + scalar_op_src = f""" +def {scalar_op_fn_name}({', '.join(input_names)}): + return direct_cast(scalar_func_numba({converted_call_args}), output_dtype) + """ + else: + scalar_op_src = f""" def {scalar_op_fn_name}({', '.join(input_names)}): - return direct_cast(scalar_func({converted_call_args}), output_dtype) - """ + return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype) + """ scalar_op_fn = compile_function_src( scalar_op_src, scalar_op_fn_name, {**globals(), **global_env} diff --git a/tests/link/numba/test_cython_support.py b/tests/link/numba/test_cython_support.py new file mode 100644 index 0000000000..c119b2fb62 --- /dev/null +++ b/tests/link/numba/test_cython_support.py @@ -0,0 +1,92 @@ +import numpy as np +import pytest +import scipy.special.cython_special +from numba.types import float32, float64, int32, int64 + +from aesara.link.numba.dispatch.cython_support import Signature, wrap_cython_function + + +@pytest.mark.parametrize( + "sig, expected_result, expected_args", + [ + (b"double(double)", np.float64, [np.float64]), + (b"float(unsigned int)", np.float32, [np.uintc]), + (b"unsigned char(unsigned short foo)", np.ubyte, [np.ushort]), + ( + b"unsigned char(unsigned short foo, double bar)", + np.ubyte, + [np.ushort, np.float64], + ), + ], +) +def test_parse_signature(sig, expected_result, expected_args): + actual = Signature.from_c_types(sig) + assert actual.res_dtype == expected_result + assert actual.arg_dtypes == expected_args + + +@pytest.mark.parametrize( + "have, want, should_provide", + [ + (b"double(int)", b"float(int)", True), + (b"float(int)", b"double(int)", False), + (b"double(unsigned short)", b"double(unsigned char)", True), + (b"double(unsigned char)", b"double(short)", False), + (b"short(double)", b"int(double)", True), + 
(b"int(double)", b"short(double)", False), + (b"float(double, int)", b"float(double, short)", True), + ], +) +def test_signature_provides(have, want, should_provide): + have = Signature.from_c_types(have) + want = Signature.from_c_types(want) + provides = have.provides(want.res_dtype, want.arg_dtypes) + assert provides == should_provide + + +@pytest.mark.parametrize( + "func, output, inputs, expected", + [ + ( + scipy.special.cython_special.agm, + np.float64, + [np.float64, np.float64], + float64(float64, float64, int32), + ), + ( + scipy.special.cython_special.erfc, + np.float64, + [np.float64], + float64(float64, int32), + ), + ( + scipy.special.cython_special.expit, + np.float32, + [np.float32], + float32(float32, int32), + ), + ( + scipy.special.cython_special.expit, + np.float64, + [np.float64], + float64(float64, int32), + ), + ( + # expn doesn't have a float32 implementation + scipy.special.cython_special.expn, + np.float32, + [np.float32, np.float32], + float64(float64, float64, int32), + ), + ( + # We choose the integer implementation if possible + scipy.special.cython_special.expn, + np.float32, + [np.int64, np.float32], + float64(int64, float64, int32), + ), + ], +) +def test_choose_signature(func, output, inputs, expected): + wrapper = wrap_cython_function(func, output, inputs) + assert wrapper.signature() == expected From d47c49e948ca47374909f568305958bc9e1f73cd Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 11 Nov 2022 15:41:50 -0600 Subject: [PATCH 03/17] Fix numba impl of CumOp --- pytensor/link/numba/dispatch/extra_ops.py | 60 ++++++++++++++++------- 1 file changed, 43 insertions(+), 17 deletions(-) diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index ad462cf58a..77cfffc05f 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -36,31 +36,57 @@ def numba_funcify_CumOp(op, node, **kwargs): mode = op.mode ndim = node.outputs[0].ndim + if axis < 0: + axis = ndim + axis + if axis < 0 or axis >= ndim: + raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}") + reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis) if mode == "add": - np_func = np.add - identity = 0 + + if ndim == 1: + @numba_basic.numba_njit(fastmath=config.numba__fastmath) + def cumop(x): + return np.cumsum(x) + + else: + @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + def cumop(x): + out_dtype = x.dtype + if x.shape[axis] < 2: + return x.astype(out_dtype) + + x_axis_first = x.transpose(reaxis_first) + res = np.empty(x_axis_first.shape, dtype=out_dtype) + + res[0] = x_axis_first[0] + for m in range(1, x.shape[axis]): + res[m] = res[m - 1] + x_axis_first[m] + + return res.transpose(reaxis_first) + else: - np_func = np.multiply - identity = 1 + if ndim == 1: + @numba_basic.numba_njit(fastmath=config.numba__fastmath) + def cumop(x): + return np.cumprod(x) - @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) - def cumop(x): - out_dtype = x.dtype - if x.shape[axis] < 2: - return x.astype(out_dtype) + else: + @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + def cumop(x): + out_dtype = x.dtype + if x.shape[axis] < 2: + return x.astype(out_dtype) - x_axis_first = x.transpose(reaxis_first) - res = np.empty(x_axis_first.shape, dtype=out_dtype) + x_axis_first = x.transpose(reaxis_first) + res = np.empty(x_axis_first.shape, dtype=out_dtype) - for m in numba.prange(x.shape[axis]): - if m == 0: - 
np_func(identity, x_axis_first[m], res[m]) - else: - np_func(res[m - 1], x_axis_first[m], res[m]) + res[0] = x_axis_first[0] + for m in range(1, x.shape[axis]): + res[m] = res[m - 1] * x_axis_first[m] - return res.transpose(reaxis_first) + return res.transpose(reaxis_first) return cumop From 0c44e08add68e9e5f44a4431e0cf52684b3c4f7d Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 11 Nov 2022 15:42:29 -0600 Subject: [PATCH 04/17] Allow svd(compute_uv=False) in numba --- pytensor/link/numba/dispatch/nlinalg.py | 23 ++++++----------------- tests/link/numba/test_nlinalg.py | 2 +- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/pytensor/link/numba/dispatch/nlinalg.py b/pytensor/link/numba/dispatch/nlinalg.py index aabc89b4d9..21fa34e1bb 100644 --- a/pytensor/link/numba/dispatch/nlinalg.py +++ b/pytensor/link/numba/dispatch/nlinalg.py @@ -25,31 +25,20 @@ def numba_funcify_SVD(op, node, **kwargs): full_matrices = op.full_matrices compute_uv = op.compute_uv + out_dtype = np.dtype(node.outputs[0].dtype) - if not compute_uv: - - warnings.warn( - ( - "Numba will use object mode to allow the " - "`compute_uv` argument to `numpy.linalg.svd`." - ), - UserWarning, - ) + inputs_cast = int_to_float_fn(node.inputs, out_dtype) - ret_sig = get_numba_type(node.outputs[0].type) + if not compute_uv: - @numba_basic.numba_njit + @numba_basic.numba_njit() def svd(x): - with numba.objmode(ret=ret_sig): - ret = np.linalg.svd(x, full_matrices, compute_uv) + _, ret, _ = np.linalg.svd(inputs_cast(x), full_matrices) return ret else: - out_dtype = node.outputs[0].type.numpy_dtype - inputs_cast = int_to_float_fn(node.inputs, out_dtype) - - @numba_basic.numba_njit(inline="always") + @numba_basic.numba_njit() def svd(x): return np.linalg.svd(inputs_cast(x), full_matrices) diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 9261388332..3018b9a97a 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -477,7 +477,7 @@ def test_QRFull(x, mode, exc): ), True, False, - UserWarning, + None, ), ], ) From 2e5e3b312ae198b2054f4d2a0b66301014d0f585 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 11 Nov 2022 15:43:21 -0600 Subject: [PATCH 05/17] Disable numba cache for cython functions --- pytensor/link/numba/dispatch/basic.py | 8 ++++++-- pytensor/link/numba/dispatch/scalar.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index c08e95627d..1fa344d7bb 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -48,10 +48,14 @@ def numba_njit(*args, **kwargs): + kwargs = kwargs.copy() + if "cache" not in kwargs: + kwargs["cache"] = config.numba__cache + if len(args) > 0 and callable(args[0]): - return numba.njit(*args[1:], cache=config.numba__cache, **kwargs)(args[0]) + return numba.njit(*args[1:], **kwargs)(args[0]) - return numba.njit(*args, cache=config.numba__cache, **kwargs) + return numba.njit(*args, **kwargs) def numba_vectorize(*args, **kwargs): diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index 36793393f7..8679b6fb8c 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -144,7 +144,7 @@ def {scalar_op_fn_name}({', '.join(input_names)}): signature = create_numba_signature(node, force_scalar=True) return numba_basic.numba_njit( - signature, inline="always", fastmath=config.numba__fastmath + signature, 
inline="always", fastmath=config.numba__fastmath, cache=False, )(scalar_op_fn) From 8c5a8d1126397e320df2f13d45c81b1796b835b8 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 15 Nov 2022 16:23:29 -0600 Subject: [PATCH 06/17] Remove broken AdvancedIndexing numba impls --- pytensor/link/numba/dispatch/basic.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 1fa344d7bb..197d197dc1 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -515,7 +515,6 @@ def {fn_name}({", ".join(input_names)}): @numba_funcify.register(Subtensor) -@numba_funcify.register(AdvancedSubtensor) @numba_funcify.register(AdvancedSubtensor1) def numba_funcify_Subtensor(op, node, **kwargs): @@ -533,7 +532,6 @@ def numba_funcify_Subtensor(op, node, **kwargs): @numba_funcify.register(IncSubtensor) -@numba_funcify.register(AdvancedIncSubtensor) def numba_funcify_IncSubtensor(op, node, **kwargs): incsubtensor_def_src = create_index_func( From c7ea15fb670503ac412d02a27af30975f6dcb396 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 15 Nov 2022 16:23:49 -0600 Subject: [PATCH 07/17] Normalize negative axes --- pytensor/link/numba/dispatch/elemwise.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 474bb67a1f..2ae8a3cec6 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -164,6 +164,18 @@ def create_vectorize_func( return elemwise_fn +def normalize_axis(axis, ndim): + if axis is None: + return axis + + if axis < 0: + axis = ndim + axis + + if axis < 0 or axis >= ndim: + raise np.AxisError(ndim=ndim, axis=axis) + return axis + + def create_axis_reducer( scalar_op: Op, identity: Union[np.ndarray, Number], @@ -218,6 +230,8 @@ def careduce_axis(x): """ + axis = normalize_axis(axis, ndim) + reduce_elemwise_fn_name = "careduce_axis" identity = str(identity) @@ -340,6 +354,8 @@ def careduce_maximum(input): if len(axes) == 1: return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype) + axes = [normalize_axis(axis, ndim) for axis in axes] + careduce_fn_name = f"careduce_{scalar_op}" global_env = {} to_reduce = reversed(sorted(axes)) @@ -409,6 +425,8 @@ def jit_compile_reducer(node, fn, **kwds): def create_axis_apply_fn(fn, axis, ndim, dtype): + axis = normalize_axis(axis, ndim) + reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,) @numba_basic.numba_njit(boundscheck=False) @@ -609,6 +627,8 @@ def numba_funcify_Softmax(op, node, **kwargs): x_dtype = numba.np.numpy_support.from_dtype(x_dtype) axis = op.axis + axis = normalize_axis(axis, x_at.ndim) + if axis is not None: reduce_max_py = create_axis_reducer( scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True @@ -646,6 +666,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype) axis = op.axis + axis = normalize_axis(axis, sm_at.ndim) if axis is not None: reduce_sum_py = create_axis_reducer( add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True @@ -676,6 +697,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): x_dtype = x_at.type.numpy_dtype x_dtype = numba.np.numpy_support.from_dtype(x_dtype) axis = op.axis + axis = normalize_axis(axis, x_at.ndim) if axis is not None: reduce_max_py = create_axis_reducer( From 85711a89c21eca710099b6a403204e13f6d075b0 Mon Sep 17 00:00:00 2001 From: 
Adrian Seyboldt Date: Tue, 15 Nov 2022 16:24:08 -0600 Subject: [PATCH 08/17] Add numba impl for CheckAndRaise --- pytensor/link/numba/dispatch/extra_ops.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index 77cfffc05f..b072d93c34 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -19,6 +19,7 @@ Unique, UnravelIndex, ) +from aesara.raise_op import CheckAndRaise @numba_funcify.register(Bartlett) @@ -372,3 +373,18 @@ def broadcast_to(x, *shape): return np.broadcast_to(x, scalars_shape) return broadcast_to + + +@numba_funcify.register(CheckAndRaise) +def numba_funcify_CheckAndRaise(op, node, **kwargs): + error = op.exc_type + msg = op.msg + + @numba_basic.numba_njit + def check_and_raise(x, *conditions): + for cond in conditions: + if not cond: + raise error(msg) + return x + + return check_and_raise From 3f5d4644837c58207cb8c4832854b099eee80b4a Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 16 Nov 2022 17:26:21 -0600 Subject: [PATCH 09/17] Fix some sandbox tests --- pytensor/sparse/sandbox/sp.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pytensor/sparse/sandbox/sp.py b/pytensor/sparse/sandbox/sp.py index 9b43e74c39..44f8126b9b 100644 --- a/pytensor/sparse/sandbox/sp.py +++ b/pytensor/sparse/sandbox/sp.py @@ -154,17 +154,17 @@ def evaluate(inshp, kshp, strides=(1, 1), nkern=1, mode="valid", ws=True): # FOR EACH OUTPUT PIXEL... # loop over output image height - for oy in np.arange(lbound[0], ubound[0], dy): + for oy in np.arange(lbound[0], ubound[0], dy, dtype=int): # loop over output image width - for ox in np.arange(lbound[1], ubound[1], dx): + for ox in np.arange(lbound[1], ubound[1], dx, dtype=int): # kern[l] is filter value to apply at (oj,oi) # for (iy,ix) l = 0 # noqa: E741 # ... ITERATE OVER INPUT UNITS IN RECEPTIVE FIELD - for ky in oy + np.arange(kshp[0]): - for kx in ox + np.arange(kshp[1]): + for ky in oy + np.arange(kshp[0], dtype=int): + for kx in ox + np.arange(kshp[1], dtype=int): # verify if we are still within image # boundaries. Equivalent to @@ -176,13 +176,13 @@ def evaluate(inshp, kshp, strides=(1, 1), nkern=1, mode="valid", ws=True): # convert to "valid" input space # coords used to determine column # index to write to in sparse mat - iy, ix = np.array((ky, kx)) - topleft + iy, ix = np.array((ky, kx), dtype=int) - topleft # determine raster-index of input pixel... # taking into account multiple # input features col = int( - iy * inshp[2] + ix + fmapi * np.prod(inshp[1:]) + iy * inshp[2] + ix + fmapi * np.prod(inshp[1:], dtype=int) ) # convert oy,ox values to output @@ -192,7 +192,7 @@ def evaluate(inshp, kshp, strides=(1, 1), nkern=1, mode="valid", ws=True): else: (y, x) = (oy, ox) - topleft # taking into account step size - (y, x) = np.array([y, x]) / (dy, dx) + (y, x) = np.array([y, x], dtype=int) / (dy, dx) # convert to row index of sparse matrix if ws: @@ -212,7 +212,7 @@ def evaluate(inshp, kshp, strides=(1, 1), nkern=1, mode="valid", ws=True): # onto the sparse columns (idea of # kernel map) # n*... 
only for sparse - spmat[row + n * outsize, col] = tapi + 1 + spmat[int(row + n * outsize), int(col)] = tapi + 1 # total number of active taps # (used for kmap) From b5debc7162f453ec3c4859c54c0046d0197f3524 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 16 Nov 2022 17:26:36 -0600 Subject: [PATCH 10/17] Fix test failures in numba mode --- tests/link/numba/test_basic.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 94cd33fe5c..4768ba77c2 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -681,9 +681,7 @@ def test_perform_params(): out = [out] out_fg = FunctionGraph([x], out) - - with pytest.warns(UserWarning, match=".*object mode.*"): - compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs]) + compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs]) def test_perform_type_convert(): @@ -702,9 +700,7 @@ def test_perform_type_convert(): out = [out] out_fg = FunctionGraph([x], out) - - with pytest.warns(UserWarning, match=".*object mode.*"): - compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs]) + compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs]) @pytest.mark.parametrize( From 5db483369d3ae82c1d84c250dd81309983637072 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 18 Nov 2022 11:37:43 -0600 Subject: [PATCH 11/17] Fix usage of deprecated llvmpy core --- pytensor/link/numba/dispatch/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 197d197dc1..8cbfea5658 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -10,7 +10,7 @@ import numpy as np import scipy import scipy.special -from llvmlite.llvmpy.core import Type as llvm_Type +from llvmlite.ir import Type as llvm_Type from numba import types from numba.core.errors import TypingError from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 From 8cbd67871c021c737d2544c516acaadab0bc0385 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 18 Nov 2022 13:31:57 -0600 Subject: [PATCH 12/17] Fix bug in numba impl of cumsum --- pytensor/link/numba/dispatch/extra_ops.py | 3 ++- tests/link/numba/test_extra_ops.py | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index b072d93c34..34247d5779 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -43,6 +43,7 @@ def numba_funcify_CumOp(op, node, **kwargs): raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}") reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis) + reaxis_first_inv = tuple(np.argsort(reaxis_first)) if mode == "add": @@ -65,7 +66,7 @@ def cumop(x): for m in range(1, x.shape[axis]): res[m] = res[m - 1] + x_axis_first[m] - return res.transpose(reaxis_first) + return res.transpose(reaxis_first_inv) else: if ndim == 1: diff --git a/tests/link/numba/test_extra_ops.py b/tests/link/numba/test_extra_ops.py index 4d50b5a9e0..8cf1fdc6bd 100644 --- a/tests/link/numba/test_extra_ops.py +++ b/tests/link/numba/test_extra_ops.py @@ -80,6 +80,13 @@ def test_BroadcastTo(x, shape): 1, "add", ), + ( + set_test_value( + at.dtensor3(), np.arange(30, dtype=config.floatX).reshape((2, 3, 5)) + ), + -1, + "add", + ), ( set_test_value( at.matrix(), 
np.arange(6, dtype=config.floatX).reshape((3, 2)) From 95e2e64de85cbe5411c1064ffc3e7c9df30ebbae Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 29 Nov 2022 23:38:26 -0600 Subject: [PATCH 13/17] Some leftover renames --- pytensor/link/numba/dispatch/extra_ops.py | 2 +- tests/link/numba/test_cython_support.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index 34247d5779..f98a300b09 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -19,7 +19,7 @@ Unique, UnravelIndex, ) -from aesara.raise_op import CheckAndRaise +from pytensor.raise_op import CheckAndRaise @numba_funcify.register(Bartlett) diff --git a/tests/link/numba/test_cython_support.py b/tests/link/numba/test_cython_support.py index c119b2fb62..b96a22098f 100644 --- a/tests/link/numba/test_cython_support.py +++ b/tests/link/numba/test_cython_support.py @@ -3,7 +3,7 @@ import scipy.special.cython_special from numba.types import float32, float64, int32, int64 -from aesara.link.numba.dispatch.cython_support import Signature, wrap_cython_function +from pytensor.link.numba.dispatch.cython_support import Signature, wrap_cython_function @pytest.mark.parametrize( From 004281adac604f294e39f6e7298b1a97902688e9 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 29 Nov 2022 23:40:12 -0600 Subject: [PATCH 14/17] Some small formatting and style changes --- pytensor/link/numba/dispatch/basic.py | 3 +-- pytensor/link/numba/dispatch/extra_ops.py | 6 +++++- pytensor/link/numba/dispatch/scalar.py | 5 ++++- pytensor/sparse/sandbox/sp.py | 4 +++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 8cbfea5658..883d65930b 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -49,8 +49,7 @@ def numba_njit(*args, **kwargs): kwargs = kwargs.copy() - if "cache" not in kwargs: - kwargs["cache"] = config.numba__cache + kwargs.setdefault("cache", config.numba__cache) if len(args) > 0 and callable(args[0]): return numba.njit(*args[1:], **kwargs)(args[0]) diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index f98a300b09..9871584454 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -7,6 +7,7 @@ from pytensor import config from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify +from pytensor.raise_op import CheckAndRaise from pytensor.tensor.extra_ops import ( Bartlett, BroadcastTo, @@ -19,7 +20,6 @@ Unique, UnravelIndex, ) -from pytensor.raise_op import CheckAndRaise @numba_funcify.register(Bartlett) @@ -48,11 +48,13 @@ def numba_funcify_CumOp(op, node, **kwargs): if mode == "add": if ndim == 1: + @numba_basic.numba_njit(fastmath=config.numba__fastmath) def cumop(x): return np.cumsum(x) else: + @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) def cumop(x): out_dtype = x.dtype @@ -70,11 +72,13 @@ def cumop(x): else: if ndim == 1: + @numba_basic.numba_njit(fastmath=config.numba__fastmath) def cumop(x): return np.cumprod(x) else: + @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) def cumop(x): out_dtype = x.dtype diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index 
8679b6fb8c..d6c68d3208 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -144,7 +144,10 @@ def {scalar_op_fn_name}({', '.join(input_names)}): signature = create_numba_signature(node, force_scalar=True) return numba_basic.numba_njit( - signature, inline="always", fastmath=config.numba__fastmath, cache=False, + signature, + inline="always", + fastmath=config.numba__fastmath, + cache=False, )(scalar_op_fn) diff --git a/pytensor/sparse/sandbox/sp.py b/pytensor/sparse/sandbox/sp.py index 44f8126b9b..33b7de4665 100644 --- a/pytensor/sparse/sandbox/sp.py +++ b/pytensor/sparse/sandbox/sp.py @@ -182,7 +182,9 @@ def evaluate(inshp, kshp, strides=(1, 1), nkern=1, mode="valid", ws=True): # taking into account multiple # input features col = int( - iy * inshp[2] + ix + fmapi * np.prod(inshp[1:], dtype=int) + iy * inshp[2] + + ix + + fmapi * np.prod(inshp[1:], dtype=int) ) # convert oy,ox values to output From 4a4706290e3cf3c23b33997323b3dba4a6291483 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 30 Nov 2022 12:28:50 -0600 Subject: [PATCH 15/17] Fix None in slice for numba boxing --- pytensor/link/numba/dispatch/basic.py | 33 +++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 883d65930b..fb63706139 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -1,4 +1,5 @@ import operator +import sys import warnings from contextlib import contextmanager from functools import singledispatch @@ -10,7 +11,7 @@ import numpy as np import scipy import scipy.special -from llvmlite.ir import Type as llvm_Type +from llvmlite import ir from numba import types from numba.core.errors import TypingError from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 @@ -131,7 +132,7 @@ def create_numba_signature( def slice_new(self, start, stop, step): - fnty = llvm_Type.function(self.pyobj, [self.pyobj, self.pyobj, self.pyobj]) + fnty = ir.FunctionType(self.pyobj, [self.pyobj, self.pyobj, self.pyobj]) fn = self._get_function(fnty, name="PySlice_New") return self.builder.call(fn, [start, stop, step]) @@ -150,11 +151,33 @@ def box_slice(typ, val, c): This makes it possible to return an Numba's internal representation of a ``slice`` object as a proper ``slice`` to Python. 
""" + start = c.builder.extract_value(val, 0) + stop = c.builder.extract_value(val, 1) + + none_val = ir.Constant(ir.IntType(64), sys.maxsize) + + start_is_none = c.builder.icmp_signed("==", start, none_val) + start = c.builder.select( + start_is_none, + c.pyapi.get_null_object(), + c.box(types.int64, start), + ) + + stop_is_none = c.builder.icmp_signed("==", stop, none_val) + stop = c.builder.select( + stop_is_none, + c.pyapi.get_null_object(), + c.box(types.int64, stop), + ) - start = c.box(types.int64, c.builder.extract_value(val, 0)) - stop = c.box(types.int64, c.builder.extract_value(val, 1)) if typ.has_step: - step = c.box(types.int64, c.builder.extract_value(val, 2)) + step = c.builder.extract_value(val, 2) + step_is_none = c.builder.icmp_signed("==", step, none_val) + step = c.builder.select( + step_is_none, + c.pyapi.get_null_object(), + c.box(types.int64, step), + ) else: step = c.pyapi.get_null_object() From 38b9aa5e6c660eae30b3fec4b681c016d1799391 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 30 Nov 2022 12:29:16 -0600 Subject: [PATCH 16/17] Fix tuple constructor in BaseCorrMM --- pytensor/tensor/nnet/corr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/nnet/corr.py b/pytensor/tensor/nnet/corr.py index c747fb9f7c..6051a03b9b 100644 --- a/pytensor/tensor/nnet/corr.py +++ b/pytensor/tensor/nnet/corr.py @@ -692,7 +692,7 @@ def make_node(self, img, kern): if kern.type.ndim != 4: raise TypeError("kern must be 4D tensor") - out_shape = tuple( + out_shape = ( 1 if img.type.shape[0] == 1 else None, 1 if kern.type.shape[0] == 1 else None, None, From 812994941282d7d282c7894ad954cf11c82353a4 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 2 Dec 2022 11:10:00 -0600 Subject: [PATCH 17/17] Use numpy function to normalize axis argument --- pytensor/link/numba/dispatch/elemwise.py | 26 +++++++----------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 2ae8a3cec6..569bd9b164 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -6,6 +6,7 @@ import numba import numpy as np +from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple from pytensor import config from pytensor.graph.basic import Apply @@ -164,18 +165,6 @@ def create_vectorize_func( return elemwise_fn -def normalize_axis(axis, ndim): - if axis is None: - return axis - - if axis < 0: - axis = ndim + axis - - if axis < 0 or axis >= ndim: - raise np.AxisError(ndim=ndim, axis=axis) - return axis - - def create_axis_reducer( scalar_op: Op, identity: Union[np.ndarray, Number], @@ -230,7 +219,7 @@ def careduce_axis(x): """ - axis = normalize_axis(axis, ndim) + axis = normalize_axis_index(axis, ndim) reduce_elemwise_fn_name = "careduce_axis" @@ -354,7 +343,7 @@ def careduce_maximum(input): if len(axes) == 1: return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype) - axes = [normalize_axis(axis, ndim) for axis in axes] + axes = normalize_axis_tuple(axes, ndim) careduce_fn_name = f"careduce_{scalar_op}" global_env = {} @@ -425,7 +414,7 @@ def jit_compile_reducer(node, fn, **kwds): def create_axis_apply_fn(fn, axis, ndim, dtype): - axis = normalize_axis(axis, ndim) + axis = normalize_axis_index(axis, ndim) reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,) @@ -627,9 +616,8 @@ def numba_funcify_Softmax(op, node, **kwargs): x_dtype = numba.np.numpy_support.from_dtype(x_dtype) axis = 
op.axis - axis = normalize_axis(axis, x_at.ndim) - if axis is not None: + axis = normalize_axis_index(axis, x_at.ndim) reduce_max_py = create_axis_reducer( scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True ) @@ -666,8 +654,8 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype) axis = op.axis - axis = normalize_axis(axis, sm_at.ndim) if axis is not None: + axis = normalize_axis_index(axis, sm_at.ndim) reduce_sum_py = create_axis_reducer( add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True ) @@ -697,9 +685,9 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): x_dtype = x_at.type.numpy_dtype x_dtype = numba.np.numpy_support.from_dtype(x_dtype) axis = op.axis - axis = normalize_axis(axis, x_at.ndim) if axis is not None: + axis = normalize_axis_index(axis, x_at.ndim) reduce_max_py = create_axis_reducer( scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True )
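
The cython_support helpers added in PATCH 02 can be exercised on their own, which makes the
wrapper-selection logic easier to review. A minimal sketch, mirroring
tests/link/numba/test_cython_support.py (the names `sig`, `erfc`, and `call_erfc` are
illustrative; for `erfc` a compatible implementation resolves to `float64(float64, int32)`,
where the trailing int is the `__pyx_skip_dispatch` flag):

    import numba
    import numpy as np
    import scipy.special.cython_special as cython_special

    from pytensor.link.numba.dispatch.cython_support import Signature, wrap_cython_function

    # Parse a __pyx_capi__ signature string into numpy dtypes.
    sig = Signature.from_c_types(b"double(int, double)")
    assert sig.res_dtype == np.float64
    assert sig.arg_dtypes == [np.intc, np.float64]

    # Select a compatible impl from the fused cython function.
    erfc = wrap_cython_function(cython_special.erfc, np.float64, [np.float64])

    @numba.njit
    def call_erfc(func, x):
        # The trailing np.intc(1) fills the __pyx_skip_dispatch argument,
        # just like the source generated in numba_funcify_ScalarOp.
        return func(x, np.intc(1))

    call_erfc(erfc, 0.5)  # ~= 0.4795, matching scipy.special.erfc(0.5)

Passing the wrapper in as an argument leans on numba's first-class function support for
`WrapperAddressProtocol` objects; the dispatcher itself instead closes over the wrapper as a
global of the generated source.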
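PATCH 04 avoids the object-mode fallback for `compute_uv=False` by computing the full
decomposition in nopython mode and discarding the singular vectors, since numba's
`np.linalg.svd` supports the `full_matrices` argument but not `compute_uv`. A sketch of the
same trick in user code (the function name is illustrative, and the input must be a float or
complex array):

    import numba
    import numpy as np

    @numba.njit
    def singular_values(x):
        # numba always computes u and vt; discard them to emulate
        # compute_uv=False without leaving nopython mode.
        u, s, vt = np.linalg.svd(x, True)
        return s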
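The CumOp bug fixed in PATCH 12 is visible in plain NumPy: the move-axis-to-front permutation
is not its own inverse for every axis, so transposing the result with the forward permutation
(as the old code did) returns a scrambled array. The fix restores this invariant:

    import numpy as np

    x = np.arange(30).reshape(2, 3, 5)
    axis = 2  # the new test case uses axis=-1, which normalizes to 2

    reaxis_first = (axis,) + tuple(i for i in range(x.ndim) if i != axis)  # (2, 0, 1)
    reaxis_first_inv = tuple(np.argsort(reaxis_first))  # (1, 2, 0), not (2, 0, 1)

    res = np.cumsum(x.transpose(reaxis_first), axis=0)
    assert np.array_equal(res.transpose(reaxis_first_inv), np.cumsum(x, axis=axis))
    # res.transpose(reaxis_first) has shape (3, 5, 2) here -- the old, wrong result.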
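Finally, PATCH 17 swaps the hand-rolled `normalize_axis` helper for NumPy's own
`normalize_axis_index`/`normalize_axis_tuple`, which implement the same wrap-and-validate rule
and likewise raise `np.AxisError` for out-of-range axes, so behaviour should be unchanged.
Roughly:

    from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

    assert normalize_axis_index(-1, 3) == 2           # -1 wraps to ndim - 1
    assert normalize_axis_tuple((-2, 0), 3) == (1, 0)
    normalize_axis_index(3, 3)                        # raises np.AxisError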