Skip to content

Commit 6f6e8bb

Browse files
committed
Add support for TypedList in numba backend
1 parent 1e79035 commit 6f6e8bb

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
)
4949
from pytensor.tensor.type import TensorType
5050
from pytensor.tensor.type_other import MakeSlice, NoneConst
51+
from pytensor.typed_list import TypedListType
5152

5253

5354
def global_numba_func(func):
@@ -130,6 +131,8 @@ def get_numba_type(
130131
return CSCMatrixType(numba_dtype)
131132

132133
raise NotImplementedError()
134+
elif isinstance(pytensor_type, TypedListType):
135+
return numba.types.List(get_numba_type(pytensor_type.ttype))
133136
else:
134137
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
135138

0 commit comments

Comments
 (0)