 )
 from pytensor.scalar.basic import ScalarType
 from pytensor.scalar.math import Softplus
+from pytensor.sparse.type import SparseTensorType
 from pytensor.tensor.blas import BatchedDot
 from pytensor.tensor.math import Dot
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@@ -65,14 +66,33 @@ def numba_vectorize(*args, **kwargs):
     return numba.vectorize(*args, cache=config.numba__cache, **kwargs)


-def get_numba_type(
-    pytensor_type: Type,
+@singledispatch
+def get_numba_type(pytensor_type: Type, **kwargs) -> numba.types.Type:
+    r"""Create a Numba type object for a :class:`Type`."""
+    return numba.types.pyobject
+
+
+@get_numba_type.register(SparseTensorType)
+def get_numba_type_SparseType(pytensor_type, **kwargs):
+    # This is needed to differentiate `SparseTensorType` from `TensorType`
+    return numba.types.pyobject
+
+
+@get_numba_type.register(ScalarType)
+def get_numba_type_ScalarType(pytensor_type, **kwargs):
+    dtype = np.dtype(pytensor_type.dtype)
+    numba_dtype = numba.from_dtype(dtype)
+    return numba_dtype
+
+
+@get_numba_type.register(TensorType)
+def get_numba_type_TensorType(
+    pytensor_type,
     layout: str = "A",
     force_scalar: bool = False,
     reduce_to_scalar: bool = False,
-) -> numba.types.Type:
-    r"""Create a Numba type object for a :class:`Type`.
-
+):
+    r"""
     Parameters
     ----------
     pytensor_type
@@ -84,44 +104,27 @@ def get_numba_type(
     reduce_to_scalar
         Return Numba scalars for zero dimensional :class:`TensorType`\s.
     """
-
-    if isinstance(pytensor_type, TensorType):
-        dtype = pytensor_type.numpy_dtype
-        numba_dtype = numba.from_dtype(dtype)
-        if force_scalar or (
-            reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0
-        ):
-            return numba_dtype
-        return numba.types.Array(numba_dtype, pytensor_type.ndim, layout)
-    elif isinstance(pytensor_type, ScalarType):
-        dtype = np.dtype(pytensor_type.dtype)
-        numba_dtype = numba.from_dtype(dtype)
+    dtype = pytensor_type.numpy_dtype
+    numba_dtype = numba.from_dtype(dtype)
+    if force_scalar or (reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0):
         return numba_dtype
-    else:
-        raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
+    return numba.types.Array(numba_dtype, pytensor_type.ndim, layout)


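A rough usage sketch of the new dispatch (illustrative, not part of the
diff; the Numba reprs in the comments are approximate). Note that
`SparseTensorType` subclasses `TensorType`, so without its own
registration `singledispatch` would route sparse types to the dense
implementation; the explicit registration instead falls back to object
mode, and the old `NotImplementedError` for unknown types is replaced by
the `numba.types.pyobject` default in the base implementation:

    import pytensor.tensor as pt
    from pytensor.scalar.basic import ScalarType
    from pytensor.sparse.type import SparseTensorType

    dense = pt.dmatrix().type                 # TensorType, float64, ndim=2
    get_numba_type(dense)                     # array(float64, 2d, A)
    get_numba_type(dense, force_scalar=True)  # float64 (Numba scalar)

    get_numba_type(ScalarType("int32"))       # int32 (Numba scalar)

    # Dispatches to get_numba_type_SparseType, not the TensorType version
    get_numba_type(SparseTensorType("csr", dtype="float64"))  # pyobject
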
 def create_numba_signature(
-    node_or_fgraph: Union[FunctionGraph, Apply],
-    force_scalar: bool = False,
-    reduce_to_scalar: bool = False,
+    node_or_fgraph: Union[FunctionGraph, Apply], **kwargs
 ) -> numba.types.Type:
     """Create a Numba type for the signature of an `Apply` node or `FunctionGraph`."""
     input_types = []
     for inp in node_or_fgraph.inputs:
-        input_types.append(
-            get_numba_type(
-                inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
-            )
-        )
+        input_types.append(get_numba_type(inp.type, **kwargs))

     output_types = []
     for out in node_or_fgraph.outputs:
-        output_types.append(
-            get_numba_type(
-                out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
-            )
-        )
+        output_types.append(get_numba_type(out.type, **kwargs))
+
+    if isinstance(node_or_fgraph, FunctionGraph):
+        return numba.types.Tuple(output_types)(*input_types)

     if len(output_types) > 1:
         return numba.types.Tuple(output_types)(*input_types)
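A similar sketch for the reworked `create_numba_signature` (hypothetical
two-output graph; the printed signature is approximate). Keyword
arguments such as `force_scalar` now reach `get_numba_type` through
`**kwargs` instead of being spelled out per parameter, and a
`FunctionGraph` always gets a tuple return type, even with a single
output, while a single-output `Apply` node keeps a plain return type:

    import pytensor.tensor as pt
    from pytensor.graph.fg import FunctionGraph

    x = pt.dvector("x")
    y = pt.dvector("y")
    fgraph = FunctionGraph(inputs=[x, y], outputs=[x + y, x * y])

    sig = create_numba_signature(fgraph)
    # roughly: UniTuple(array(float64, 1d, A) x 2)(
    #              array(float64, 1d, A), array(float64, 1d, A))

    sig_scalar = create_numba_signature(fgraph, force_scalar=True)
    # roughly: UniTuple(float64 x 2)(float64, float64)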