@@ -31,6 +31,7 @@
 )
 from pytensor.scalar.basic import ScalarType
 from pytensor.scalar.math import Softplus
+from pytensor.sparse.type import SparseTensorType
 from pytensor.tensor.blas import BatchedDot
 from pytensor.tensor.math import Dot
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@@ -73,14 +74,33 @@ def numba_vectorize(*args, **kwargs):
 )
 
 
-def get_numba_type(
-    pytensor_type: Type,
+@singledispatch
+def get_numba_type(pytensor_type: Type, **kwargs) -> numba.types.Type:
+    r"""Create a Numba type object for a :class:`Type`."""
+    return numba.types.pyobject
+
+
+@get_numba_type.register(SparseTensorType)
+def get_numba_type_SparseType(pytensor_type, **kwargs):
+    # This is needed to differentiate `SparseTensorType` from `TensorType`
+    return numba.types.pyobject
+
+
+@get_numba_type.register(ScalarType)
+def get_numba_type_ScalarType(pytensor_type, **kwargs):
+    dtype = np.dtype(pytensor_type.dtype)
+    numba_dtype = numba.from_dtype(dtype)
+    return numba_dtype
+
+
+@get_numba_type.register(TensorType)
+def get_numba_type_TensorType(
+    pytensor_type,
     layout: str = "A",
     force_scalar: bool = False,
     reduce_to_scalar: bool = False,
-) -> numba.types.Type:
-    r"""Create a Numba type object for a :class:`Type`.
-
+):
+    r"""
     Parameters
     ----------
     pytensor_type
@@ -92,44 +112,27 @@ def get_numba_type(
     reduce_to_scalar
         Return Numba scalars for zero dimensional :class:`TensorType`\s.
     """
-
-    if isinstance(pytensor_type, TensorType):
-        dtype = pytensor_type.numpy_dtype
-        numba_dtype = numba.from_dtype(dtype)
-        if force_scalar or (
-            reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0
-        ):
-            return numba_dtype
-        return numba.types.Array(numba_dtype, pytensor_type.ndim, layout)
-    elif isinstance(pytensor_type, ScalarType):
-        dtype = np.dtype(pytensor_type.dtype)
-        numba_dtype = numba.from_dtype(dtype)
+    dtype = pytensor_type.numpy_dtype
+    numba_dtype = numba.from_dtype(dtype)
+    if force_scalar or (reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0):
         return numba_dtype
-    else:
-        raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
+    return numba.types.Array(numba_dtype, pytensor_type.ndim, layout)
 
 
 def create_numba_signature(
-    node_or_fgraph: Union[FunctionGraph, Apply],
-    force_scalar: bool = False,
-    reduce_to_scalar: bool = False,
+    node_or_fgraph: Union[FunctionGraph, Apply], **kwargs
 ) -> numba.types.Type:
     """Create a Numba type for the signature of an `Apply` node or `FunctionGraph`."""
     input_types = []
     for inp in node_or_fgraph.inputs:
-        input_types.append(
-            get_numba_type(
-                inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
-            )
-        )
+        input_types.append(get_numba_type(inp.type, **kwargs))
 
     output_types = []
     for out in node_or_fgraph.outputs:
-        output_types.append(
-            get_numba_type(
-                out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
-            )
-        )
+        output_types.append(get_numba_type(out.type, **kwargs))
+
+    if isinstance(node_or_fgraph, FunctionGraph):
+        return numba.types.Tuple(output_types)(*input_types)
 
     if len(output_types) > 1:
         return numba.types.Tuple(output_types)(*input_types)
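With this change `get_numba_type` is a `functools.singledispatch` function: dispatch happens on the `Type` of the first argument, keyword arguments such as `layout` and `force_scalar` only matter for the `TensorType` registration, and unrecognized `Type`s fall back to `numba.types.pyobject` instead of raising `NotImplementedError`. A minimal usage sketch, assuming the diff above is applied; the import location is an assumption based on where this module lives in PyTensor:

import numba
from pytensor.scalar.basic import ScalarType
from pytensor.tensor.type import TensorType

# Assumed import location of the refactored dispatcher.
from pytensor.link.numba.dispatch.basic import get_numba_type

vec = TensorType("float64", shape=(None,))  # a 1-d float64 tensor type

get_numba_type(vec)                     # numba.types.Array(float64, 1, 'A')
get_numba_type(vec, force_scalar=True)  # numba.types.float64
get_numba_type(ScalarType("int32"))     # numba.types.int32

Note that `create_numba_signature` now simply forwards its keyword arguments to `get_numba_type`, and a `FunctionGraph` always gets a tuple return type, even when it has a single output.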
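Because the dispatcher is open, support for further `Type`s can be registered from outside this module rather than by extending an `isinstance` chain. A hedged sketch, using the same assumed import location and a hypothetical `MyType` that is not part of this diff:

import numba
from pytensor.graph.type import Type
from pytensor.link.numba.dispatch.basic import get_numba_type


class MyType(Type):
    """Hypothetical Type, used purely to illustrate the registration hook."""

    def filter(self, data, strict=False, allow_downcast=None):
        return data

    def __eq__(self, other):
        return isinstance(other, MyType)

    def __hash__(self):
        return hash(type(self))


@get_numba_type.register(MyType)
def get_numba_type_MyType(pytensor_type, **kwargs):
    # No native Numba representation; fall back to object mode.
    return numba.types.pyobject


assert get_numba_type(MyType()) is numba.types.pyobject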