From 27459b68dccca3a0487504e6c4d035d2156d06c1 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Mon, 28 Nov 2022 21:25:36 -0600 Subject: [PATCH 1/5] Make scalar Categorical sampling work in recent versions of Numba --- pytensor/link/numba/dispatch/random.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 23a762715b..dff7fab1ba 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -5,7 +5,15 @@ import numpy as np from numba import _helperlib, types from numba.core import cgutils -from numba.extending import NativeValue, box, models, register_model, typeof_impl, unbox +from numba.extending import ( + NativeValue, + box, + models, + overload, + register_model, + typeof_impl, + unbox, +) from numpy.random import RandomState import pytensor.tensor.random.basic as aer @@ -78,6 +86,16 @@ def box_random_state(typ, val, c): return class_obj +@overload(np.random.uniform) +def uniform_empty_size(a, b, size): + if isinstance(size, types.Tuple) and size.count == 0: + + def uniform_no_size(a, b, size): + return np.random.uniform(a, b) + + return uniform_no_size + + @numba_typify.register(RandomState) def numba_typify_RandomState(state, **kwargs): # The numba_typify in this case is just an passthrough function @@ -321,7 +339,7 @@ def categorical_rv(rng, size, dtype, p): size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) p = np.broadcast_to(p, size_tpl + p.shape[-1:]) - unif_samples = np.random.uniform(0, 1, size_tpl) + unif_samples = np.asarray(np.random.uniform(0, 1, size_tpl)) res = np.empty(size_tpl, dtype=out_dtype) for idx in np.ndindex(*size_tpl): From 809c228014c9ca395e83714fa61d64c9610d0e0e Mon Sep 17 00:00:00 2001 From: "Brandon T. 
Willard" Date: Mon, 28 Nov 2022 23:32:41 -0600 Subject: [PATCH 2/5] Change type_conversion_fn to const_conversion_fn in fgraph_to_python --- pytensor/link/jax/dispatch/basic.py | 2 +- pytensor/link/numba/dispatch/basic.py | 2 +- pytensor/link/utils.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pytensor/link/jax/dispatch/basic.py b/pytensor/link/jax/dispatch/basic.py index 0e0fbec20b..cc88b68287 100644 --- a/pytensor/link/jax/dispatch/basic.py +++ b/pytensor/link/jax/dispatch/basic.py @@ -49,7 +49,7 @@ def jax_funcify_FunctionGraph( return fgraph_to_python( fgraph, jax_funcify, - type_conversion_fn=jax_typify, + const_conversion_fn=jax_typify, fgraph_name=fgraph_name, **kwargs, ) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 3287831622..0c74cd7a80 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -442,7 +442,7 @@ def numba_funcify_FunctionGraph( return fgraph_to_python( fgraph, numba_funcify, - type_conversion_fn=numba_typify, + const_conversion_fn=numba_typify, fgraph_name=fgraph_name, **kwargs, ) diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py index fd76e1278e..69503d04a6 100644 --- a/pytensor/link/utils.py +++ b/pytensor/link/utils.py @@ -678,7 +678,7 @@ def fgraph_to_python( fgraph: FunctionGraph, op_conversion_fn: Callable, *, - type_conversion_fn: Callable = lambda x, **kwargs: x, + const_conversion_fn: Callable = lambda x, **kwargs: x, order: Optional[List[Apply]] = None, storage_map: Optional["StorageMapType"] = None, fgraph_name: str = "fgraph_to_python", @@ -698,8 +698,8 @@ def fgraph_to_python( A callable used to convert nodes inside `fgraph` based on their `Op` types. It must have the signature ``(op: Op, node: Apply=None, storage_map: Dict[Variable, List[Optional[Any]]]=None, **kwargs)``. - type_conversion_fn - A callable used to convert the values in `storage_map`. It must have + const_conversion_fn + A callable used to convert the `Constant` values in `storage_map`. It must have the signature ``(value: Optional[Any], variable: Variable=None, storage: List[Optional[Any]]=None, **kwargs)``. order @@ -753,7 +753,7 @@ def fgraph_to_python( ) if input_storage[0] is not None or isinstance(i, Constant): # Constants need to be assigned locally and referenced - global_env[local_input_name] = type_conversion_fn( + global_env[local_input_name] = const_conversion_fn( input_storage[0], variable=i, storage=input_storage, **kwargs ) # TODO: We could attempt to use the storage arrays directly @@ -776,7 +776,7 @@ def fgraph_to_python( output_storage = storage_map.setdefault( out, [None if not isinstance(out, Constant) else out.data] ) - global_env[local_output_name] = type_conversion_fn( + global_env[local_output_name] = const_conversion_fn( output_storage[0], variable=out, storage=output_storage, From e6ac03dcf03501fa855b7bcfd902feaa26306eef Mon Sep 17 00:00:00 2001 From: "Brandon T. 
Willard" Date: Mon, 28 Nov 2022 23:29:48 -0600 Subject: [PATCH 3/5] Change numba_typify to numba_const_convert --- pytensor/link/numba/dispatch/__init__.py | 2 +- pytensor/link/numba/dispatch/basic.py | 5 +++-- pytensor/link/numba/dispatch/random.py | 10 +++++----- pytensor/link/numba/linker.py | 4 ++-- tests/link/numba/test_basic.py | 4 ++-- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index c7cb2632a1..e0fb2b2fda 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -1,5 +1,5 @@ # isort: off -from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify +from pytensor.link.numba.dispatch.basic import numba_funcify, numba_const_convert # Load dispatch specializations import pytensor.link.numba.dispatch.scalar diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 0c74cd7a80..63e8c2227d 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -345,7 +345,8 @@ def use_optimized_cheap_pass(*args, **kwargs): @singledispatch -def numba_typify(data, dtype=None, **kwargs): +def numba_const_convert(data, dtype=None, **kwargs): + """Create a Numba compatible constant from an PyTensor `Constant`.""" return data @@ -442,7 +443,7 @@ def numba_funcify_FunctionGraph( return fgraph_to_python( fgraph, numba_funcify, - const_conversion_fn=numba_typify, + const_conversion_fn=numba_const_convert, fgraph_name=fgraph_name, **kwargs, ) diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index dff7fab1ba..5dbaad3f8d 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -20,7 +20,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify +from pytensor.link.numba.dispatch.basic import numba_const_convert, numba_funcify from pytensor.link.utils import ( compile_function_src, get_name_for_object, @@ -96,11 +96,11 @@ def uniform_no_size(a, b, size): return uniform_no_size -@numba_typify.register(RandomState) -def numba_typify_RandomState(state, **kwargs): - # The numba_typify in this case is just an passthrough function +@numba_const_convert.register(RandomState) +def numba_const_convert_RandomState(state, **kwargs): + # The `numba_const_convert` in this case is just a passthrough function # that synchronizes Numba's internal random state with the current - # RandomState object + # `RandomState` object. 
ints, index = state.get_state()[1:3] ptr = _helperlib.rnd_get_np_state_ptr() _helperlib.rnd_set_state(ptr, (index, [int(x) for x in ints])) diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 7cddedbc58..1dbe29d299 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -35,13 +35,13 @@ def jit_compile(self, fn): def create_thunk_inputs(self, storage_map): from numpy.random import RandomState - from pytensor.link.numba.dispatch import numba_typify + from pytensor.link.numba.dispatch import numba_const_convert thunk_inputs = [] for n in self.fgraph.inputs: sinput = storage_map[n] if isinstance(sinput[0], RandomState): - new_value = numba_typify( + new_value = numba_const_convert( sinput[0], dtype=getattr(sinput[0], "dtype", None) ) # We need to remove the reference-based connection to the diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 887ec63d9b..de351236d1 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -24,7 +24,7 @@ from pytensor.graph.type import Type from pytensor.ifelse import ifelse from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch import numba_typify +from pytensor.link.numba.dispatch import numba_const_convert from pytensor.link.numba.linker import NumbaLinker from pytensor.raise_op import assert_op from pytensor.tensor import blas @@ -321,7 +321,7 @@ def test_create_numba_signature(v, expected, force_scalar): [ ( np.random.RandomState(1), - numba_typify, + numba_const_convert, lambda x, y: np.all(x.get_state()[1] == y.get_state()[1]), ) ], From b1b97ee91dc931d3db640ef2306ff562a1d7ab17 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Thu, 1 Dec 2022 01:03:49 +0100 Subject: [PATCH 4/5] Make get_numba_type dispatch on Type --- pytensor/link/numba/dispatch/basic.py | 67 ++++++++++++++------------- tests/link/numba/test_basic.py | 30 ++++++------ 2 files changed, 48 insertions(+), 49 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 63e8c2227d..7c7fec85dc 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -31,6 +31,7 @@ ) from pytensor.scalar.basic import ScalarType from pytensor.scalar.math import Softplus +from pytensor.sparse.type import SparseTensorType from pytensor.tensor.blas import BatchedDot from pytensor.tensor.math import Dot from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape @@ -65,14 +66,33 @@ def numba_vectorize(*args, **kwargs): return numba.vectorize(*args, cache=config.numba__cache, **kwargs) -def get_numba_type( - pytensor_type: Type, +@singledispatch +def get_numba_type(pytensor_type: Type, **kwargs) -> numba.types.Type: + r"""Create a Numba type object for a :class:`Type`.""" + return numba.types.pyobject + + +@get_numba_type.register(SparseTensorType) +def get_numba_type_SparseType(pytensor_type, **kwargs): + # This is needed to differentiate `SparseTensorType` from `TensorType` + return numba.types.pyobject + + +@get_numba_type.register(ScalarType) +def get_numba_type_ScalarType(pytensor_type, **kwargs): + dtype = np.dtype(pytensor_type.dtype) + numba_dtype = numba.from_dtype(dtype) + return numba_dtype + + +@get_numba_type.register(TensorType) +def get_numba_type_TensorType( + pytensor_type, layout: str = "A", force_scalar: bool = False, reduce_to_scalar: bool = False, -) -> numba.types.Type: - r"""Create a Numba type object for a :class:`Type`. 
- +): + r""" Parameters ---------- pytensor_type @@ -84,44 +104,27 @@ def get_numba_type( reduce_to_scalar Return Numba scalars for zero dimensional :class:`TensorType`\s. """ - - if isinstance(pytensor_type, TensorType): - dtype = pytensor_type.numpy_dtype - numba_dtype = numba.from_dtype(dtype) - if force_scalar or ( - reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0 - ): - return numba_dtype - return numba.types.Array(numba_dtype, pytensor_type.ndim, layout) - elif isinstance(pytensor_type, ScalarType): - dtype = np.dtype(pytensor_type.dtype) - numba_dtype = numba.from_dtype(dtype) + dtype = pytensor_type.numpy_dtype + numba_dtype = numba.from_dtype(dtype) + if force_scalar or (reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0): return numba_dtype - else: - raise NotImplementedError(f"Numba type not implemented for {pytensor_type}") + return numba.types.Array(numba_dtype, pytensor_type.ndim, layout) def create_numba_signature( - node_or_fgraph: Union[FunctionGraph, Apply], - force_scalar: bool = False, - reduce_to_scalar: bool = False, + node_or_fgraph: Union[FunctionGraph, Apply], **kwargs ) -> numba.types.Type: """Create a Numba type for the signature of an `Apply` node or `FunctionGraph`.""" input_types = [] for inp in node_or_fgraph.inputs: - input_types.append( - get_numba_type( - inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar - ) - ) + input_types.append(get_numba_type(inp.type, **kwargs)) output_types = [] for out in node_or_fgraph.outputs: - output_types.append( - get_numba_type( - out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar - ) - ) + output_types.append(get_numba_type(out.type, **kwargs)) + + if isinstance(node_or_fgraph, FunctionGraph): + return numba.types.Tuple(output_types)(*input_types) if len(output_types) > 1: return numba.types.Tuple(output_types)(*input_types) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index de351236d1..fd55be5222 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -27,6 +27,7 @@ from pytensor.link.numba.dispatch import numba_const_convert from pytensor.link.numba.linker import NumbaLinker from pytensor.raise_op import assert_op +from pytensor.sparse.type import SparseTensorType from pytensor.tensor import blas from pytensor.tensor import subtensor as at_subtensor from pytensor.tensor.elemwise import Elemwise @@ -252,26 +253,21 @@ def assert_fn(x, y): @pytest.mark.parametrize( - "v, expected, force_scalar, not_implemented", + "v, expected, force_scalar", [ - (MyType(), None, False, True), - (aes.float32, numba.types.float32, False, False), - (at.fscalar, numba.types.Array(numba.types.float32, 0, "A"), False, False), - (at.fscalar, numba.types.float32, True, False), - (at.lvector, numba.types.int64[:], False, False), - (at.dmatrix, numba.types.float64[:, :], False, False), - (at.dmatrix, numba.types.float64, True, False), + (MyType(), numba.types.pyobject, False), + (SparseTensorType("csc", dtype=np.float64), numba.types.pyobject, False), + (aes.float32, numba.types.float32, False), + (at.fscalar, numba.types.Array(numba.types.float32, 0, "A"), False), + (at.fscalar, numba.types.float32, True), + (at.lvector, numba.types.int64[:], False), + (at.dmatrix, numba.types.float64[:, :], False), + (at.dmatrix, numba.types.float64, True), ], ) -def test_get_numba_type(v, expected, force_scalar, not_implemented): - cm = ( - contextlib.suppress() - if not not_implemented - else pytest.raises(NotImplementedError) - ) - with 
cm: - res = numba_basic.get_numba_type(v, force_scalar=force_scalar) - assert res == expected +def test_get_numba_type(v, expected, force_scalar): + res = numba_basic.get_numba_type(v, force_scalar=force_scalar) + assert res == expected @pytest.mark.parametrize( From dc5e943535265c914b34f91f8cf2853179e027ec Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Fri, 16 Dec 2022 15:44:18 +0100 Subject: [PATCH 5/5] Separate interface and dispatch of numba_funcify --- doc/extending/creating_a_numba_jax_op.rst | 4 +- pytensor/link/numba/dispatch/__init__.py | 6 +- pytensor/link/numba/dispatch/basic.py | 98 +++++++++++++------- pytensor/link/numba/dispatch/elemwise.py | 15 +-- pytensor/link/numba/dispatch/extra_ops.py | 24 ++--- pytensor/link/numba/dispatch/nlinalg.py | 18 ++-- pytensor/link/numba/dispatch/random.py | 60 ++++++------ pytensor/link/numba/dispatch/scalar.py | 39 ++++---- pytensor/link/numba/dispatch/scan.py | 3 +- pytensor/link/numba/dispatch/tensor_basic.py | 26 +++--- pytensor/link/numba/linker.py | 4 +- 11 files changed, 169 insertions(+), 128 deletions(-) diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index f6e50556bf..36858f4f8a 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -83,7 +83,7 @@ Here's an example for :class:`IfElse`: return res if n_outs > 1 else res[0] -Step 3: Register the function with the `jax_funcify` dispatcher +Step 3: Register the function with the `_jax_funcify` dispatcher --------------------------------------------------------------- With the PyTensor `Op` replicated in JAX, we’ll need to register the @@ -91,7 +91,7 @@ function with the PyTensor JAX `Linker`. This is done through the use of `singledispatch`. If you don't know how `singledispatch` works, see the `Python documentation `_. -The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.numba_funcify` and +The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.basic._numba_funcify` and :func:`pytensor.link.jax.dispatch.jax_funcify`. 
Here’s an example for the `Eye`\ `Op`: diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index e0fb2b2fda..9d47d25385 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -1,5 +1,9 @@ # isort: off -from pytensor.link.numba.dispatch.basic import numba_funcify, numba_const_convert +from pytensor.link.numba.dispatch.basic import ( + numba_funcify, + numba_const_convert, + numba_njit, +) # Load dispatch specializations import pytensor.link.numba.dispatch.scalar diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 7c7fec85dc..25b37f312b 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -4,7 +4,7 @@ from contextlib import contextmanager from functools import singledispatch from textwrap import dedent -from typing import Union +from typing import TYPE_CHECKING, Callable, Optional, Union, cast import numba import numba.np.unsafe.ndarray as numba_ndarray @@ -22,6 +22,7 @@ from pytensor.compile.ops import DeepCopyOp from pytensor.graph.basic import Apply, NoParams from pytensor.graph.fg import FunctionGraph +from pytensor.graph.op import Op from pytensor.graph.type import Type from pytensor.ifelse import IfElse from pytensor.link.utils import ( @@ -48,6 +49,10 @@ from pytensor.tensor.type_other import MakeSlice, NoneConst +if TYPE_CHECKING: + from pytensor.graph.op import StorageMapType + + def numba_njit(*args, **kwargs): kwargs = kwargs.copy() @@ -353,8 +358,43 @@ def numba_const_convert(data, dtype=None, **kwargs): return data -def generate_fallback_impl(op, node=None, storage_map=None, **kwargs): - """Create a Numba compatible function from an Aesara `Op`.""" +def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable: + """Convert `obj` to a Numba-JITable object.""" + return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs) + + +@singledispatch +def _numba_funcify( + obj, + node: Optional[Apply] = None, + storage_map: Optional["StorageMapType"] = None, + **kwargs, +) -> Callable: + r"""Dispatch on PyTensor object types to perform Numba conversions. + + Arguments + --------- + obj + The object used to determine the appropriate conversion function based + on its type. This is generally an `Op` instance, but `FunctionGraph`\s + are also supported. + node + When `obj` is an `Op`, this value should be the corresponding `Apply` node. + storage_map + A storage map with, for example, the constant and `SharedVariable` values + of the graph being converted. + + Returns + ------- + A `Callable` that can be JIT-compiled in Numba using `numba.jit`. 
+ + """ + raise NotImplementedError(f"Numba funcify for obj {obj} not implemented") + + +@_numba_funcify.register(Op) +def numba_funcify_perform(op, node, storage_map=None, **kwargs) -> Callable: + """Create a Numba compatible function from an PyTensor `Op.perform`.""" warnings.warn( f"Numba will use object mode to run {op}'s perform method", @@ -405,16 +445,10 @@ def perform(*inputs): ret = py_perform_return(inputs) return ret - return perform - - -@singledispatch -def numba_funcify(op, node=None, storage_map=None, **kwargs): - """Generate a numba function for a given op and apply node.""" - return generate_fallback_impl(op, node, storage_map, **kwargs) + return cast(Callable, perform) -@numba_funcify.register(OpFromGraph) +@_numba_funcify.register(OpFromGraph) def numba_funcify_OpFromGraph(op, node=None, **kwargs): _ = kwargs.pop("storage_map", None) @@ -436,7 +470,7 @@ def opfromgraph(*inputs): return opfromgraph -@numba_funcify.register(FunctionGraph) +@_numba_funcify.register(FunctionGraph) def numba_funcify_FunctionGraph( fgraph, node=None, @@ -544,8 +578,8 @@ def {fn_name}({", ".join(input_names)}): return subtensor_def_src -@numba_funcify.register(Subtensor) -@numba_funcify.register(AdvancedSubtensor1) +@_numba_funcify.register(Subtensor) +@_numba_funcify.register(AdvancedSubtensor1) def numba_funcify_Subtensor(op, node, **kwargs): subtensor_def_src = create_index_func( @@ -561,7 +595,7 @@ def numba_funcify_Subtensor(op, node, **kwargs): return numba_njit(subtensor_fn, boundscheck=True) -@numba_funcify.register(IncSubtensor) +@_numba_funcify.register(IncSubtensor) def numba_funcify_IncSubtensor(op, node, **kwargs): incsubtensor_def_src = create_index_func( @@ -577,7 +611,7 @@ def numba_funcify_IncSubtensor(op, node, **kwargs): return numba_njit(incsubtensor_fn, boundscheck=True) -@numba_funcify.register(AdvancedIncSubtensor1) +@_numba_funcify.register(AdvancedIncSubtensor1) def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): inplace = op.inplace set_instead_of_inc = op.set_instead_of_inc @@ -610,7 +644,7 @@ def advancedincsubtensor1(x, vals, idxs): return advancedincsubtensor1 -@numba_funcify.register(DeepCopyOp) +@_numba_funcify.register(DeepCopyOp) def numba_funcify_DeepCopyOp(op, node, **kwargs): # Scalars are apparently returned as actual Python scalar types and not @@ -632,8 +666,8 @@ def deepcopyop(x): return deepcopyop -@numba_funcify.register(MakeSlice) -def numba_funcify_MakeSlice(op, **kwargs): +@_numba_funcify.register(MakeSlice) +def numba_funcify_MakeSlice(op, node, **kwargs): @numba_njit def makeslice(*x): return slice(*x) @@ -641,8 +675,8 @@ def makeslice(*x): return makeslice -@numba_funcify.register(Shape) -def numba_funcify_Shape(op, **kwargs): +@_numba_funcify.register(Shape) +def numba_funcify_Shape(op, node, **kwargs): @numba_njit(inline="always") def shape(x): return np.asarray(np.shape(x)) @@ -650,8 +684,8 @@ def shape(x): return shape -@numba_funcify.register(Shape_i) -def numba_funcify_Shape_i(op, **kwargs): +@_numba_funcify.register(Shape_i) +def numba_funcify_Shape_i(op, node, **kwargs): i = op.i @numba_njit(inline="always") @@ -681,8 +715,8 @@ def codegen(context, builder, signature, args): return sig, codegen -@numba_funcify.register(Reshape) -def numba_funcify_Reshape(op, **kwargs): +@_numba_funcify.register(Reshape) +def numba_funcify_Reshape(op, node, **kwargs): ndim = op.ndim if ndim == 0: @@ -704,7 +738,7 @@ def reshape(x, shape): return reshape -@numba_funcify.register(SpecifyShape) +@_numba_funcify.register(SpecifyShape) def 
numba_funcify_SpecifyShape(op, node, **kwargs): shape_inputs = node.inputs[1:] shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] @@ -751,7 +785,7 @@ def inputs_cast(x): return inputs_cast -@numba_funcify.register(Dot) +@_numba_funcify.register(Dot) def numba_funcify_Dot(op, node, **kwargs): # Numba's `np.dot` does not support integer dtypes, so we need to cast to # float. @@ -766,7 +800,7 @@ def dot(x, y): return dot -@numba_funcify.register(Softplus) +@_numba_funcify.register(Softplus) def numba_funcify_Softplus(op, node, **kwargs): x_dtype = np.dtype(node.inputs[0].dtype) @@ -785,7 +819,7 @@ def softplus(x): return softplus -@numba_funcify.register(Cholesky) +@_numba_funcify.register(Cholesky) def numba_funcify_Cholesky(op, node, **kwargs): lower = op.lower @@ -821,7 +855,7 @@ def cholesky(a): return cholesky -@numba_funcify.register(Solve) +@_numba_funcify.register(Solve) def numba_funcify_Solve(op, node, **kwargs): assume_a = op.assume_a @@ -868,7 +902,7 @@ def solve(a, b): return solve -@numba_funcify.register(BatchedDot) +@_numba_funcify.register(BatchedDot) def numba_funcify_BatchedDot(op, node, **kwargs): dtype = node.outputs[0].type.numpy_dtype @@ -889,7 +923,7 @@ def batched_dot(x, y): # optimizations are apparently already performed by Numba -@numba_funcify.register(IfElse) +@_numba_funcify.register(IfElse) def numba_funcify_IfElse(op, **kwargs): n_outs = op.n_outs diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 0595191da0..e884d138f9 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -13,6 +13,7 @@ from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( + _numba_funcify, create_numba_signature, create_tuple_creator, numba_funcify, @@ -431,7 +432,7 @@ def axis_apply_fn(x): return axis_apply_fn -@numba_funcify.register(Elemwise) +@_numba_funcify.register(Elemwise) def numba_funcify_Elemwise(op, node, **kwargs): # Creating a new scalar node is more involved and unnecessary # if the scalar_op is composite, as the fgraph already contains @@ -492,7 +493,7 @@ def {inplace_elemwise_fn_name}({input_signature_str}): return elemwise_fn -@numba_funcify.register(CAReduce) +@_numba_funcify.register(CAReduce) def numba_funcify_CAReduce(op, node, **kwargs): axes = op.axis if axes is None: @@ -530,7 +531,7 @@ def numba_funcify_CAReduce(op, node, **kwargs): return careduce_fn -@numba_funcify.register(DimShuffle) +@_numba_funcify.register(DimShuffle) def numba_funcify_DimShuffle(op, node, **kwargs): shuffle = tuple(op.shuffle) transposition = tuple(op.transposition) @@ -628,7 +629,7 @@ def dimshuffle(x): return dimshuffle -@numba_funcify.register(Softmax) +@_numba_funcify.register(Softmax) def numba_funcify_Softmax(op, node, **kwargs): x_at = node.inputs[0] @@ -666,7 +667,7 @@ def softmax_py_fn(x): return softmax -@numba_funcify.register(SoftmaxGrad) +@_numba_funcify.register(SoftmaxGrad) def numba_funcify_SoftmaxGrad(op, node, **kwargs): sm_at = node.inputs[1] @@ -698,7 +699,7 @@ def softmax_grad_py_fn(dy, sm): return softmax_grad -@numba_funcify.register(LogSoftmax) +@_numba_funcify.register(LogSoftmax) def numba_funcify_LogSoftmax(op, node, **kwargs): x_at = node.inputs[0] @@ -733,7 +734,7 @@ def log_softmax_py_fn(x): return log_softmax -@numba_funcify.register(MaxAndArgmax) +@_numba_funcify.register(MaxAndArgmax) def numba_funcify_MaxAndArgmax(op, node, **kwargs): axis = 
op.axis x_at = node.inputs[0] diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index 9871584454..78422273a7 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -6,7 +6,7 @@ from pytensor import config from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify +from pytensor.link.numba.dispatch.basic import _numba_funcify, get_numba_type from pytensor.raise_op import CheckAndRaise from pytensor.tensor.extra_ops import ( Bartlett, @@ -22,7 +22,7 @@ ) -@numba_funcify.register(Bartlett) +@_numba_funcify.register(Bartlett) def numba_funcify_Bartlett(op, **kwargs): @numba_basic.numba_njit(inline="always") def bartlett(x): @@ -31,7 +31,7 @@ def bartlett(x): return bartlett -@numba_funcify.register(CumOp) +@_numba_funcify.register(CumOp) def numba_funcify_CumOp(op, node, **kwargs): axis = op.axis mode = op.mode @@ -97,7 +97,7 @@ def cumop(x): return cumop -@numba_funcify.register(FillDiagonal) +@_numba_funcify.register(FillDiagonal) def numba_funcify_FillDiagonal(op, **kwargs): @numba_basic.numba_njit def filldiagonal(a, val): @@ -107,7 +107,7 @@ def filldiagonal(a, val): return filldiagonal -@numba_funcify.register(FillDiagonalOffset) +@_numba_funcify.register(FillDiagonalOffset) def numba_funcify_FillDiagonalOffset(op, node, **kwargs): @numba_basic.numba_njit def filldiagonaloffset(a, val, offset): @@ -132,7 +132,7 @@ def filldiagonaloffset(a, val, offset): return filldiagonaloffset -@numba_funcify.register(RavelMultiIndex) +@_numba_funcify.register(RavelMultiIndex) def numba_funcify_RavelMultiIndex(op, node, **kwargs): mode = op.mode @@ -197,7 +197,7 @@ def ravelmultiindex(*inp): return ravelmultiindex -@numba_funcify.register(Repeat) +@_numba_funcify.register(Repeat) def numba_funcify_Repeat(op, node, **kwargs): axis = op.axis @@ -242,7 +242,7 @@ def repeatop(x, repeats): return repeatop -@numba_funcify.register(Unique) +@_numba_funcify.register(Unique) def numba_funcify_Unique(op, node, **kwargs): axis = op.axis @@ -288,7 +288,7 @@ def unique(x): return unique -@numba_funcify.register(UnravelIndex) +@_numba_funcify.register(UnravelIndex) def numba_funcify_UnravelIndex(op, node, **kwargs): order = op.order @@ -323,7 +323,7 @@ def unravelindex(arr, shape): return unravelindex -@numba_funcify.register(SearchsortedOp) +@_numba_funcify.register(SearchsortedOp) def numba_funcify_Searchsorted(op, node, **kwargs): side = op.side @@ -357,7 +357,7 @@ def searchsorted(a, v): return searchsorted -@numba_funcify.register(BroadcastTo) +@_numba_funcify.register(BroadcastTo) def numba_funcify_BroadcastTo(op, node, **kwargs): create_zeros_tuple = numba_basic.create_tuple_creator( @@ -380,7 +380,7 @@ def broadcast_to(x, *shape): return broadcast_to -@numba_funcify.register(CheckAndRaise) +@_numba_funcify.register(CheckAndRaise) def numba_funcify_CheckAndRaise(op, node, **kwargs): error = op.exc_type msg = op.msg diff --git a/pytensor/link/numba/dispatch/nlinalg.py b/pytensor/link/numba/dispatch/nlinalg.py index 21fa34e1bb..ec24f94782 100644 --- a/pytensor/link/numba/dispatch/nlinalg.py +++ b/pytensor/link/numba/dispatch/nlinalg.py @@ -5,9 +5,9 @@ from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( + _numba_funcify, get_numba_type, int_to_float_fn, - numba_funcify, ) from pytensor.tensor.nlinalg import ( SVD, @@ -21,7 +21,7 @@ ) -@numba_funcify.register(SVD) 
+@_numba_funcify.register(SVD) def numba_funcify_SVD(op, node, **kwargs): full_matrices = op.full_matrices compute_uv = op.compute_uv @@ -45,7 +45,7 @@ def svd(x): return svd -@numba_funcify.register(Det) +@_numba_funcify.register(Det) def numba_funcify_Det(op, node, **kwargs): out_dtype = node.outputs[0].type.numpy_dtype @@ -58,7 +58,7 @@ def det(x): return det -@numba_funcify.register(Eig) +@_numba_funcify.register(Eig) def numba_funcify_Eig(op, node, **kwargs): out_dtype_1 = node.outputs[0].type.numpy_dtype @@ -74,7 +74,7 @@ def eig(x): return eig -@numba_funcify.register(Eigh) +@_numba_funcify.register(Eigh) def numba_funcify_Eigh(op, node, **kwargs): uplo = op.UPLO @@ -109,7 +109,7 @@ def eigh(x): return eigh -@numba_funcify.register(Inv) +@_numba_funcify.register(Inv) def numba_funcify_Inv(op, node, **kwargs): out_dtype = node.outputs[0].type.numpy_dtype @@ -122,7 +122,7 @@ def inv(x): return inv -@numba_funcify.register(MatrixInverse) +@_numba_funcify.register(MatrixInverse) def numba_funcify_MatrixInverse(op, node, **kwargs): out_dtype = node.outputs[0].type.numpy_dtype @@ -135,7 +135,7 @@ def matrix_inverse(x): return matrix_inverse -@numba_funcify.register(MatrixPinv) +@_numba_funcify.register(MatrixPinv) def numba_funcify_MatrixPinv(op, node, **kwargs): out_dtype = node.outputs[0].type.numpy_dtype @@ -148,7 +148,7 @@ def matrixpinv(x): return matrixpinv -@numba_funcify.register(QRFull) +@_numba_funcify.register(QRFull) def numba_funcify_QRFull(op, node, **kwargs): mode = op.mode diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 5dbaad3f8d..80005f2350 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -20,7 +20,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import numba_const_convert, numba_funcify +from pytensor.link.numba.dispatch.basic import _numba_funcify, numba_const_convert from pytensor.link.utils import ( compile_function_src, get_name_for_object, @@ -207,29 +207,29 @@ def {sized_fn_name}({random_fn_input_names}): return random_fn -@numba_funcify.register(aer.UniformRV) -@numba_funcify.register(aer.TriangularRV) -@numba_funcify.register(aer.BetaRV) -@numba_funcify.register(aer.NormalRV) -@numba_funcify.register(aer.LogNormalRV) -@numba_funcify.register(aer.GammaRV) -@numba_funcify.register(aer.ChiSquareRV) -@numba_funcify.register(aer.ParetoRV) -@numba_funcify.register(aer.GumbelRV) -@numba_funcify.register(aer.ExponentialRV) -@numba_funcify.register(aer.WeibullRV) -@numba_funcify.register(aer.LogisticRV) -@numba_funcify.register(aer.VonMisesRV) -@numba_funcify.register(aer.PoissonRV) -@numba_funcify.register(aer.GeometricRV) -@numba_funcify.register(aer.HyperGeometricRV) -@numba_funcify.register(aer.WaldRV) -@numba_funcify.register(aer.LaplaceRV) -@numba_funcify.register(aer.BinomialRV) -@numba_funcify.register(aer.MultinomialRV) -@numba_funcify.register(aer.RandIntRV) # only the first two arguments are supported -@numba_funcify.register(aer.ChoiceRV) # the `p` argument is not supported -@numba_funcify.register(aer.PermutationRV) +@_numba_funcify.register(aer.UniformRV) +@_numba_funcify.register(aer.TriangularRV) +@_numba_funcify.register(aer.BetaRV) +@_numba_funcify.register(aer.NormalRV) +@_numba_funcify.register(aer.LogNormalRV) +@_numba_funcify.register(aer.GammaRV) +@_numba_funcify.register(aer.ChiSquareRV) 
+@_numba_funcify.register(aer.ParetoRV) +@_numba_funcify.register(aer.GumbelRV) +@_numba_funcify.register(aer.ExponentialRV) +@_numba_funcify.register(aer.WeibullRV) +@_numba_funcify.register(aer.LogisticRV) +@_numba_funcify.register(aer.VonMisesRV) +@_numba_funcify.register(aer.PoissonRV) +@_numba_funcify.register(aer.GeometricRV) +@_numba_funcify.register(aer.HyperGeometricRV) +@_numba_funcify.register(aer.WaldRV) +@_numba_funcify.register(aer.LaplaceRV) +@_numba_funcify.register(aer.BinomialRV) +@_numba_funcify.register(aer.MultinomialRV) +@_numba_funcify.register(aer.RandIntRV) # only the first two arguments are supported +@_numba_funcify.register(aer.ChoiceRV) # the `p` argument is not supported +@_numba_funcify.register(aer.PermutationRV) def numba_funcify_RandomVariable(op, node, **kwargs): name = op.name np_random_func = getattr(np.random, name) @@ -285,12 +285,12 @@ def {np_random_fn_name}({np_input_names}): return make_numba_random_fn(node, np_random_fn) -@numba_funcify.register(aer.NegBinomialRV) +@_numba_funcify.register(aer.NegBinomialRV) def numba_funcify_NegBinomialRV(op, node, **kwargs): return make_numba_random_fn(node, np.random.negative_binomial) -@numba_funcify.register(aer.CauchyRV) +@_numba_funcify.register(aer.CauchyRV) def numba_funcify_CauchyRV(op, node, **kwargs): def body_fn(loc, scale): return f" return ({loc} + np.random.standard_cauchy()) / {scale}" @@ -298,7 +298,7 @@ def body_fn(loc, scale): return create_numba_random_fn(op, node, body_fn) -@numba_funcify.register(aer.HalfNormalRV) +@_numba_funcify.register(aer.HalfNormalRV) def numba_funcify_HalfNormalRV(op, node, **kwargs): def body_fn(a, b): return f" return {a} + {b} * abs(np.random.normal(0, 1))" @@ -306,7 +306,7 @@ def body_fn(a, b): return create_numba_random_fn(op, node, body_fn) -@numba_funcify.register(aer.BernoulliRV) +@_numba_funcify.register(aer.BernoulliRV) def numba_funcify_BernoulliRV(op, node, **kwargs): out_dtype = node.outputs[1].type.numpy_dtype @@ -326,7 +326,7 @@ def body_fn(a): ) -@numba_funcify.register(aer.CategoricalRV) +@_numba_funcify.register(aer.CategoricalRV) def numba_funcify_CategoricalRV(op, node, **kwargs): out_dtype = node.outputs[1].type.numpy_dtype size_len = int(get_vector_length(node.inputs[1])) @@ -350,7 +350,7 @@ def categorical_rv(rng, size, dtype, p): return categorical_rv -@numba_funcify.register(aer.DirichletRV) +@_numba_funcify.register(aer.DirichletRV) def numba_funcify_DirichletRV(op, node, **kwargs): out_dtype = node.outputs[1].type.numpy_dtype diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index d6c68d3208..d1e8314cde 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -8,9 +8,10 @@ from pytensor.graph.basic import Variable from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( + _numba_funcify, create_numba_signature, - generate_fallback_impl, numba_funcify, + numba_funcify_perform, ) from pytensor.link.numba.dispatch.cython_support import wrap_cython_function from pytensor.link.utils import ( @@ -33,7 +34,7 @@ from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid -@numba_funcify.register(ScalarOp) +@_numba_funcify.register(ScalarOp) def numba_funcify_ScalarOp(op, node, **kwargs): # TODO: Do we need to cache these functions so that we don't end up # compiling the same Numba function over and over again? 
@@ -57,7 +58,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs): input_inner_dtypes = None output_inner_dtype = None - # Cython functions might have an additonal argument + # Cython functions might have an additional argument has_pyx_skip_dispatch = False if scalar_func_path.startswith("scipy.special"): @@ -76,7 +77,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs): # pass if scalar_func_numba is None: - scalar_func_numba = generate_fallback_impl(op, node, **kwargs) + scalar_func_numba = numba_funcify_perform(op, node, **kwargs) scalar_op_fn_name = get_name_for_object(scalar_func_numba) unique_names = unique_name_generator( @@ -151,7 +152,7 @@ def {scalar_op_fn_name}({', '.join(input_names)}): )(scalar_op_fn) -@numba_funcify.register(Switch) +@_numba_funcify.register(Switch) def numba_funcify_Switch(op, node, **kwargs): @numba_basic.numba_njit(inline="always") def switch(condition, x, y): @@ -179,7 +180,7 @@ def {binary_op_name}({input_signature}): return nary_fn -@numba_funcify.register(Add) +@_numba_funcify.register(Add) def numba_funcify_Add(op, node, **kwargs): signature = create_numba_signature(node, force_scalar=True) @@ -191,7 +192,7 @@ def numba_funcify_Add(op, node, **kwargs): )(nary_add_fn) -@numba_funcify.register(Mul) +@_numba_funcify.register(Mul) def numba_funcify_Mul(op, node, **kwargs): signature = create_numba_signature(node, force_scalar=True) @@ -203,7 +204,7 @@ def numba_funcify_Mul(op, node, **kwargs): )(nary_mul_fn) -@numba_funcify.register(Cast) +@_numba_funcify.register(Cast) def numba_funcify_Cast(op, node, **kwargs): dtype = np.dtype(op.o_type.dtype) @@ -215,8 +216,8 @@ def cast(x): return cast -@numba_funcify.register(Identity) -@numba_funcify.register(ViewOp) +@_numba_funcify.register(Identity) +@_numba_funcify.register(ViewOp) def numba_funcify_ViewOp(op, **kwargs): @numba_basic.numba_njit(inline="always") def viewop(x): @@ -225,7 +226,7 @@ def viewop(x): return viewop -@numba_funcify.register(Clip) +@_numba_funcify.register(Clip) def numba_funcify_Clip(op, **kwargs): @numba_basic.numba_njit def clip(_x, _min, _max): @@ -243,7 +244,7 @@ def clip(_x, _min, _max): return clip -@numba_funcify.register(Composite) +@_numba_funcify.register(Composite) def numba_funcify_Composite(op, node, **kwargs): signature = create_numba_signature(op.fgraph, force_scalar=True) @@ -255,7 +256,7 @@ def numba_funcify_Composite(op, node, **kwargs): return composite_fn -@numba_funcify.register(Second) +@_numba_funcify.register(Second) def numba_funcify_Second(op, node, **kwargs): @numba_basic.numba_njit(inline="always") def second(x, y): @@ -264,7 +265,7 @@ def second(x, y): return second -@numba_funcify.register(Reciprocal) +@_numba_funcify.register(Reciprocal) def numba_funcify_Reciprocal(op, node, **kwargs): @numba_basic.numba_njit(inline="always") def reciprocal(x): @@ -275,7 +276,7 @@ def reciprocal(x): return reciprocal -@numba_funcify.register(Sigmoid) +@_numba_funcify.register(Sigmoid) def numba_funcify_Sigmoid(op, node, **kwargs): @numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath) def sigmoid(x): @@ -284,7 +285,7 @@ def sigmoid(x): return sigmoid -@numba_funcify.register(GammaLn) +@_numba_funcify.register(GammaLn) def numba_funcify_GammaLn(op, node, **kwargs): @numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath) def gammaln(x): @@ -293,7 +294,7 @@ def gammaln(x): return gammaln -@numba_funcify.register(Log1mexp) +@_numba_funcify.register(Log1mexp) def numba_funcify_Log1mexp(op, node, **kwargs): 
@numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath) def logp1mexp(x): @@ -305,7 +306,7 @@ def logp1mexp(x): return logp1mexp -@numba_funcify.register(Erf) +@_numba_funcify.register(Erf) def numba_funcify_Erf(op, **kwargs): @numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath) def erf(x): @@ -314,7 +315,7 @@ def erf(x): return erf -@numba_funcify.register(Erfc) +@_numba_funcify.register(Erfc) def numba_funcify_Erfc(op, **kwargs): @numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath) def erfc(x): diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index c26cd9aa6c..4a2a88c51a 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -7,6 +7,7 @@ from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( + _numba_funcify, create_arg_string, create_tuple_string, numba_funcify, @@ -45,7 +46,7 @@ def range_arr(x): return range_arr -@numba_funcify.register(Scan) +@_numba_funcify.register(Scan) def numba_funcify_Scan(op, node, **kwargs): # Apply inner rewrites # TODO: Not sure this is the right place to do this, should we have a rewrite that diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index cef1ded67a..87b9ed07c5 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -3,7 +3,7 @@ import numpy as np from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcify +from pytensor.link.numba.dispatch.basic import _numba_funcify, create_tuple_string from pytensor.link.utils import compile_function_src, unique_name_generator from pytensor.tensor.basic import ( Alloc, @@ -21,7 +21,7 @@ from pytensor.tensor.shape import Unbroadcast -@numba_funcify.register(AllocEmpty) +@_numba_funcify.register(AllocEmpty) def numba_funcify_AllocEmpty(op, node, **kwargs): global_env = { @@ -59,7 +59,7 @@ def allocempty({", ".join(shape_var_names)}): return numba_basic.numba_njit(alloc_fn) -@numba_funcify.register(Alloc) +@_numba_funcify.register(Alloc) def numba_funcify_Alloc(op, node, **kwargs): global_env = {"np": np, "to_scalar": numba_basic.to_scalar} @@ -95,7 +95,7 @@ def alloc(val, {", ".join(shape_var_names)}): return numba_basic.numba_njit(alloc_fn) -@numba_funcify.register(AllocDiag) +@_numba_funcify.register(AllocDiag) def numba_funcify_AllocDiag(op, **kwargs): offset = op.offset @@ -106,7 +106,7 @@ def allocdiag(v): return allocdiag -@numba_funcify.register(ARange) +@_numba_funcify.register(ARange) def numba_funcify_ARange(op, **kwargs): dtype = np.dtype(op.dtype) @@ -122,7 +122,7 @@ def arange(start, stop, step): return arange -@numba_funcify.register(Join) +@_numba_funcify.register(Join) def numba_funcify_Join(op, **kwargs): view = op.view @@ -139,7 +139,7 @@ def join(axis, *tensors): return join -@numba_funcify.register(Split) +@_numba_funcify.register(Split) def numba_funcify_Split(op, **kwargs): @numba_basic.numba_njit def split(tensor, axis, indices): @@ -151,7 +151,7 @@ def split(tensor, axis, indices): return split -@numba_funcify.register(ExtractDiag) +@_numba_funcify.register(ExtractDiag) def numba_funcify_ExtractDiag(op, **kwargs): offset = op.offset # axis1 = op.axis1 @@ -164,7 +164,7 @@ def extract_diag(x): return extract_diag -@numba_funcify.register(Eye) +@_numba_funcify.register(Eye) def numba_funcify_Eye(op, 
**kwargs): dtype = np.dtype(op.dtype) @@ -180,7 +180,7 @@ def eye(N, M, k): return eye -@numba_funcify.register(MakeVector) +@_numba_funcify.register(MakeVector) def numba_funcify_MakeVector(op, node, **kwargs): dtype = np.dtype(op.dtype) @@ -208,7 +208,7 @@ def makevector({", ".join(input_names)}): return numba_basic.numba_njit(makevector_fn) -@numba_funcify.register(Unbroadcast) +@_numba_funcify.register(Unbroadcast) def numba_funcify_Unbroadcast(op, **kwargs): @numba_basic.numba_njit def unbroadcast(x): @@ -217,7 +217,7 @@ def unbroadcast(x): return unbroadcast -@numba_funcify.register(TensorFromScalar) +@_numba_funcify.register(TensorFromScalar) def numba_funcify_TensorFromScalar(op, **kwargs): @numba_basic.numba_njit(inline="always") def tensor_from_scalar(x): @@ -226,7 +226,7 @@ def tensor_from_scalar(x): return tensor_from_scalar -@numba_funcify.register(ScalarFromTensor) +@_numba_funcify.register(ScalarFromTensor) def numba_funcify_ScalarFromTensor(op, **kwargs): @numba_basic.numba_njit(inline="always") def scalar_from_tensor(x): diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 1dbe29d299..851f9ac6aa 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -27,9 +27,9 @@ def fgraph_convert(self, fgraph, **kwargs): return numba_funcify(fgraph, **kwargs) def jit_compile(self, fn): - import numba + from pytensor.link.numba.dispatch import numba_njit - jitted_fn = numba.njit(fn) + jitted_fn = numba_njit(fn) return jitted_fn def create_thunk_inputs(self, storage_map):
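
As a companion to PATCH 4/5, here is a minimal sketch of how the new singledispatch-based `get_numba_type` is meant to be extended for an additional `Type`. It assumes only the registrations shown in the diff above; `MyWrapperType` and its registration function are hypothetical names used for illustration and are not part of PyTensor.

import numba

from pytensor.graph.type import Type
from pytensor.link.numba.dispatch.basic import get_numba_type


class MyWrapperType(Type):
    """Hypothetical PyTensor Type, used only to illustrate registration."""

    ...


@get_numba_type.register(MyWrapperType)
def get_numba_type_MyWrapperType(pytensor_type, **kwargs):
    # Keep values of this type boxed as Python objects, mirroring the
    # `SparseTensorType` registration added in the patch.  Unregistered
    # `Type`s now fall back to `numba.types.pyobject` instead of raising
    # `NotImplementedError`.
    return numba.types.pyobject

With that fallback in place, `create_numba_signature` can simply forward its keyword arguments (e.g. `force_scalar`) to the dispatcher, which is what the updated `test_get_numba_type` cases exercise.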
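
Likewise for PATCH 5/5, a sketch of what registering an `Op` conversion looks like once the public `numba_funcify` interface is separated from the `_numba_funcify` dispatcher. `MyOp` and the jitted `my_op` body are hypothetical placeholders; the import paths follow the ones used in the diffs above.

from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import numba_funcify, numba_njit
from pytensor.link.numba.dispatch.basic import _numba_funcify


class MyOp(Op):
    """Hypothetical Op, used only to illustrate the dispatch split."""

    ...


# Conversions are registered on the private singledispatch function ...
@_numba_funcify.register(MyOp)
def numba_funcify_MyOp(op, node=None, **kwargs):
    @numba_njit(inline="always")
    def my_op(x):
        return 2 * x

    return my_op


# ... while callers (e.g. `NumbaLinker.fgraph_convert`) keep going through the
# public wrapper, which simply forwards to the dispatcher:
#
#     jitable_fn = numba_funcify(MyOp(), node=apply_node)
#
# `Op`s without a registration now fall back to the object-mode
# `numba_funcify_perform` implementation registered for the base `Op` class.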