Commit b1b97ee

brandonwillard authored and ricardoV94 committed
Make get_numba_type dispatch on Type
1 parent e6ac03d · commit b1b97ee
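This commit replaces the `isinstance` chain inside `get_numba_type` with `functools.singledispatch`: each `Type` subclass now registers its own conversion, and unregistered types fall back to `numba.types.pyobject` instead of raising `NotImplementedError`. A minimal, self-contained sketch of the dispatch pattern (the classes and return values below are illustrative stand-ins, not pytensor's real ones):

from functools import singledispatch


class Type:  # stand-in for pytensor's base `Type`
    pass


class TensorType(Type):  # stand-in subclass
    pass


@singledispatch
def get_numba_type(pytensor_type, **kwargs):
    # Fallback for any unregistered class: a generic object sentinel,
    # mirroring the commit's `numba.types.pyobject` default.
    return "pyobject"


@get_numba_type.register(TensorType)
def _(pytensor_type, **kwargs):
    # Subclass-specific conversion lives in its own registered function,
    # so new types extend the dispatcher without touching this module.
    return "array"


print(get_numba_type(Type()))        # pyobject
print(get_numba_type(TensorType()))  # array

Because `singledispatch` resolves to the most specific registered class, the separate `SparseTensorType` registration in the diff below exists to keep sparse types from falling through to the `TensorType` handler, as its inline comment notes.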

File tree

2 files changed · +48 −49 lines changed


pytensor/link/numba/dispatch/basic.py

Lines changed: 35 additions & 32 deletions
@@ -31,6 +31,7 @@
 )
 from pytensor.scalar.basic import ScalarType
 from pytensor.scalar.math import Softplus
+from pytensor.sparse.type import SparseTensorType
 from pytensor.tensor.blas import BatchedDot
 from pytensor.tensor.math import Dot
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@@ -65,14 +66,33 @@ def numba_vectorize(*args, **kwargs):
     return numba.vectorize(*args, cache=config.numba__cache, **kwargs)
 
 
-def get_numba_type(
-    pytensor_type: Type,
+@singledispatch
+def get_numba_type(pytensor_type: Type, **kwargs) -> numba.types.Type:
+    r"""Create a Numba type object for a :class:`Type`."""
+    return numba.types.pyobject
+
+
+@get_numba_type.register(SparseTensorType)
+def get_numba_type_SparseType(pytensor_type, **kwargs):
+    # This is needed to differentiate `SparseTensorType` from `TensorType`
+    return numba.types.pyobject
+
+
+@get_numba_type.register(ScalarType)
+def get_numba_type_ScalarType(pytensor_type, **kwargs):
+    dtype = np.dtype(pytensor_type.dtype)
+    numba_dtype = numba.from_dtype(dtype)
+    return numba_dtype
+
+
+@get_numba_type.register(TensorType)
+def get_numba_type_TensorType(
+    pytensor_type,
     layout: str = "A",
     force_scalar: bool = False,
     reduce_to_scalar: bool = False,
-) -> numba.types.Type:
-    r"""Create a Numba type object for a :class:`Type`.
-
+):
+    r"""
     Parameters
     ----------
     pytensor_type
@@ -84,44 +104,27 @@ def get_numba_type(
     reduce_to_scalar
         Return Numba scalars for zero dimensional :class:`TensorType`\s.
     """
-
-    if isinstance(pytensor_type, TensorType):
-        dtype = pytensor_type.numpy_dtype
-        numba_dtype = numba.from_dtype(dtype)
-        if force_scalar or (
-            reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0
-        ):
-            return numba_dtype
-        return numba.types.Array(numba_dtype, pytensor_type.ndim, layout)
-    elif isinstance(pytensor_type, ScalarType):
-        dtype = np.dtype(pytensor_type.dtype)
-        numba_dtype = numba.from_dtype(dtype)
+    dtype = pytensor_type.numpy_dtype
+    numba_dtype = numba.from_dtype(dtype)
+    if force_scalar or (reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0):
         return numba_dtype
-    else:
-        raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
+    return numba.types.Array(numba_dtype, pytensor_type.ndim, layout)
 
 
 def create_numba_signature(
-    node_or_fgraph: Union[FunctionGraph, Apply],
-    force_scalar: bool = False,
-    reduce_to_scalar: bool = False,
+    node_or_fgraph: Union[FunctionGraph, Apply], **kwargs
 ) -> numba.types.Type:
     """Create a Numba type for the signature of an `Apply` node or `FunctionGraph`."""
     input_types = []
     for inp in node_or_fgraph.inputs:
-        input_types.append(
-            get_numba_type(
-                inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
-            )
-        )
+        input_types.append(get_numba_type(inp.type, **kwargs))
 
     output_types = []
     for out in node_or_fgraph.outputs:
-        output_types.append(
-            get_numba_type(
-                out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
-            )
-        )
+        output_types.append(get_numba_type(out.type, **kwargs))
+
+    if isinstance(node_or_fgraph, FunctionGraph):
+        return numba.types.Tuple(output_types)(*input_types)
 
     if len(output_types) > 1:
         return numba.types.Tuple(output_types)(*input_types)
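With the dispatcher in place, keyword arguments simply flow through `**kwargs`, which is why `create_numba_signature` no longer spells out `force_scalar` and `reduce_to_scalar`. A usage sketch whose expected values are taken from the updated test parametrization below (assuming, as in the test module, that `numba_basic` aliases `pytensor.link.numba.dispatch.basic`):

import numba
import numpy as np
import pytensor.tensor as at
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.sparse.type import SparseTensorType

# `at.dmatrix` is a `TensorType` instance, so it hits the registered handler.
assert numba_basic.get_numba_type(at.dmatrix) == numba.types.float64[:, :]
assert numba_basic.get_numba_type(at.dmatrix, force_scalar=True) == numba.types.float64

# Sparse types now map to `numba.types.pyobject` instead of raising
# `NotImplementedError`.
assert (
    numba_basic.get_numba_type(SparseTensorType("csc", dtype=np.float64))
    == numba.types.pyobject
)

Note also the behavioral tweak at the end of the hunk: `create_numba_signature` now returns a tuple-typed signature for any `FunctionGraph`, even one with a single output, while the `len(output_types) > 1` branch still handles multi-output `Apply` nodes.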

tests/link/numba/test_basic.py

Lines changed: 13 additions & 17 deletions
@@ -27,6 +27,7 @@
 from pytensor.link.numba.dispatch import numba_const_convert
 from pytensor.link.numba.linker import NumbaLinker
 from pytensor.raise_op import assert_op
+from pytensor.sparse.type import SparseTensorType
 from pytensor.tensor import blas
 from pytensor.tensor import subtensor as at_subtensor
 from pytensor.tensor.elemwise import Elemwise
@@ -252,26 +253,21 @@ def assert_fn(x, y):
 
 
 @pytest.mark.parametrize(
-    "v, expected, force_scalar, not_implemented",
+    "v, expected, force_scalar",
     [
-        (MyType(), None, False, True),
-        (aes.float32, numba.types.float32, False, False),
-        (at.fscalar, numba.types.Array(numba.types.float32, 0, "A"), False, False),
-        (at.fscalar, numba.types.float32, True, False),
-        (at.lvector, numba.types.int64[:], False, False),
-        (at.dmatrix, numba.types.float64[:, :], False, False),
-        (at.dmatrix, numba.types.float64, True, False),
+        (MyType(), numba.types.pyobject, False),
+        (SparseTensorType("csc", dtype=np.float64), numba.types.pyobject, False),
+        (aes.float32, numba.types.float32, False),
+        (at.fscalar, numba.types.Array(numba.types.float32, 0, "A"), False),
+        (at.fscalar, numba.types.float32, True),
+        (at.lvector, numba.types.int64[:], False),
+        (at.dmatrix, numba.types.float64[:, :], False),
+        (at.dmatrix, numba.types.float64, True),
     ],
 )
-def test_get_numba_type(v, expected, force_scalar, not_implemented):
-    cm = (
-        contextlib.suppress()
-        if not not_implemented
-        else pytest.raises(NotImplementedError)
-    )
-    with cm:
-        res = numba_basic.get_numba_type(v, force_scalar=force_scalar)
-        assert res == expected
+def test_get_numba_type(v, expected, force_scalar):
+    res = numba_basic.get_numba_type(v, force_scalar=force_scalar)
+    assert res == expected
 
 
 @pytest.mark.parametrize(
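Because `get_numba_type` is now a `singledispatch` function, downstream code can extend it for its own `Type` subclasses without touching pytensor itself. A hypothetical registration sketch (`MyCustomType` and its int64 mapping are invented for illustration; the `Type` import path is assumed to be `pytensor.graph.type`):

import numba
from pytensor.graph.type import Type
from pytensor.link.numba.dispatch import basic as numba_basic


class MyCustomType(Type):
    """Hypothetical user-defined Type; details omitted for brevity."""


@numba_basic.get_numba_type.register(MyCustomType)
def get_numba_type_MyCustomType(pytensor_type, **kwargs):
    # Hypothetical choice: lower this type as a 64-bit integer scalar.
    return numba.types.int64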

0 commit comments