
Commit bdfba42

brandonwillard authored and twiecki committed
Make get_numba_type dispatch on Type
1 parent aaeb88a commit bdfba42
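
This commit replaces the isinstance chain in `get_numba_type` with `functools.singledispatch`, so each PyTensor `Type` gets its own registered conversion and unregistered types fall back to `numba.types.pyobject` rather than raising `NotImplementedError`. A minimal, self-contained sketch of that pattern follows; the `Type` and `MyCustomType` classes here are hypothetical stand-ins, not PyTensor's own:

from functools import singledispatch

import numba


class Type:
    """Stand-in for PyTensor's graph-level `Type` base class (illustration only)."""


class MyCustomType(Type):
    """Hypothetical user-defined type, for illustration only."""


@singledispatch
def get_numba_type(pytensor_type: Type, **kwargs) -> numba.types.Type:
    # Fallback: unregistered `Type`s become opaque Python objects
    # instead of raising `NotImplementedError`.
    return numba.types.pyobject


@get_numba_type.register(MyCustomType)
def _(pytensor_type, **kwargs):
    # A registered subclass gets its own conversion rule.
    return numba.types.float64


print(get_numba_type(Type()))          # pyobject
print(get_numba_type(MyCustomType()))  # float64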

File tree

pytensor/link/numba/dispatch/basic.py
tests/link/numba/test_basic.py

2 files changed: +48 −49 lines changed


pytensor/link/numba/dispatch/basic.py

Lines changed: 35 additions & 32 deletions
@@ -31,6 +31,7 @@
 )
 from pytensor.scalar.basic import ScalarType
 from pytensor.scalar.math import Softplus
+from pytensor.sparse.type import SparseTensorType
 from pytensor.tensor.blas import BatchedDot
 from pytensor.tensor.math import Dot
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@@ -73,14 +74,33 @@ def numba_vectorize(*args, **kwargs):
 )


-def get_numba_type(
-    pytensor_type: Type,
+@singledispatch
+def get_numba_type(pytensor_type: Type, **kwargs) -> numba.types.Type:
+    r"""Create a Numba type object for a :class:`Type`."""
+    return numba.types.pyobject
+
+
+@get_numba_type.register(SparseTensorType)
+def get_numba_type_SparseType(pytensor_type, **kwargs):
+    # This is needed to differentiate `SparseTensorType` from `TensorType`
+    return numba.types.pyobject
+
+
+@get_numba_type.register(ScalarType)
+def get_numba_type_ScalarType(pytensor_type, **kwargs):
+    dtype = np.dtype(pytensor_type.dtype)
+    numba_dtype = numba.from_dtype(dtype)
+    return numba_dtype
+
+
+@get_numba_type.register(TensorType)
+def get_numba_type_TensorType(
+    pytensor_type,
     layout: str = "A",
     force_scalar: bool = False,
     reduce_to_scalar: bool = False,
-) -> numba.types.Type:
-    r"""Create a Numba type object for a :class:`Type`.
-
+):
+    r"""
     Parameters
     ----------
     pytensor_type
@@ -92,44 +112,27 @@ def get_numba_type(
     reduce_to_scalar
         Return Numba scalars for zero dimensional :class:`TensorType`\s.
     """
-
-    if isinstance(pytensor_type, TensorType):
-        dtype = pytensor_type.numpy_dtype
-        numba_dtype = numba.from_dtype(dtype)
-        if force_scalar or (
-            reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0
-        ):
-            return numba_dtype
-        return numba.types.Array(numba_dtype, pytensor_type.ndim, layout)
-    elif isinstance(pytensor_type, ScalarType):
-        dtype = np.dtype(pytensor_type.dtype)
-        numba_dtype = numba.from_dtype(dtype)
+    dtype = pytensor_type.numpy_dtype
+    numba_dtype = numba.from_dtype(dtype)
+    if force_scalar or (reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0):
         return numba_dtype
-    else:
-        raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
+    return numba.types.Array(numba_dtype, pytensor_type.ndim, layout)


 def create_numba_signature(
-    node_or_fgraph: Union[FunctionGraph, Apply],
-    force_scalar: bool = False,
-    reduce_to_scalar: bool = False,
+    node_or_fgraph: Union[FunctionGraph, Apply], **kwargs
 ) -> numba.types.Type:
     """Create a Numba type for the signature of an `Apply` node or `FunctionGraph`."""
     input_types = []
     for inp in node_or_fgraph.inputs:
-        input_types.append(
-            get_numba_type(
-                inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
-            )
-        )
+        input_types.append(get_numba_type(inp.type, **kwargs))

     output_types = []
     for out in node_or_fgraph.outputs:
-        output_types.append(
-            get_numba_type(
-                out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
-            )
-        )
+        output_types.append(get_numba_type(out.type, **kwargs))
+
+    if isinstance(node_or_fgraph, FunctionGraph):
+        return numba.types.Tuple(output_types)(*input_types)

     if len(output_types) > 1:
         return numba.types.Tuple(output_types)(*input_types)
tests/link/numba/test_basic.py

Lines changed: 13 additions & 17 deletions
@@ -27,6 +27,7 @@
 from pytensor.link.numba.dispatch import numba_const_convert
 from pytensor.link.numba.linker import NumbaLinker
 from pytensor.raise_op import assert_op
+from pytensor.sparse.type import SparseTensorType
 from pytensor.tensor import blas
 from pytensor.tensor import subtensor as at_subtensor
 from pytensor.tensor.elemwise import Elemwise
@@ -252,26 +253,21 @@ def assert_fn(x, y):


 @pytest.mark.parametrize(
-    "v, expected, force_scalar, not_implemented",
+    "v, expected, force_scalar",
     [
-        (MyType(), None, False, True),
-        (aes.float32, numba.types.float32, False, False),
-        (at.fscalar, numba.types.Array(numba.types.float32, 0, "A"), False, False),
-        (at.fscalar, numba.types.float32, True, False),
-        (at.lvector, numba.types.int64[:], False, False),
-        (at.dmatrix, numba.types.float64[:, :], False, False),
-        (at.dmatrix, numba.types.float64, True, False),
+        (MyType(), numba.types.pyobject, False),
+        (SparseTensorType("csc", dtype=np.float64), numba.types.pyobject, False),
+        (aes.float32, numba.types.float32, False),
+        (at.fscalar, numba.types.Array(numba.types.float32, 0, "A"), False),
+        (at.fscalar, numba.types.float32, True),
+        (at.lvector, numba.types.int64[:], False),
+        (at.dmatrix, numba.types.float64[:, :], False),
+        (at.dmatrix, numba.types.float64, True),
     ],
 )
-def test_get_numba_type(v, expected, force_scalar, not_implemented):
-    cm = (
-        contextlib.suppress()
-        if not not_implemented
-        else pytest.raises(NotImplementedError)
-    )
-    with cm:
-        res = numba_basic.get_numba_type(v, force_scalar=force_scalar)
-        assert res == expected
+def test_get_numba_type(v, expected, force_scalar):
+    res = numba_basic.get_numba_type(v, force_scalar=force_scalar)
+    assert res == expected


 @pytest.mark.parametrize(
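
Finally, a sketch of the `**kwargs` forwarding in `create_numba_signature`; the `x`/`y` graph is a hypothetical example, and the printed signature assumes the single-output branch not shown in this hunk:

import pytensor.tensor as at
from pytensor.link.numba.dispatch.basic import create_numba_signature

x = at.dmatrix("x")
y = x + 1.0

# force_scalar is forwarded through **kwargs to get_numba_type for
# every input and output type of the Apply node.
sig = create_numba_signature(y.owner, force_scalar=True)
print(sig)  # e.g. float64(float64, float64)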
