Skip to content

Commit a1e3775

Browse files
committed
Temporary patch for numba/numba#9554
1 parent 78b5120 commit a1e3775

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

pytensor/link/numba/dispatch/scalar.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Composite,
2525
Identity,
2626
Mul,
27+
Pow,
2728
Reciprocal,
2829
ScalarOp,
2930
Second,
@@ -160,6 +161,21 @@ def numba_funcify_Switch(op, node, **kwargs):
160161
return numba_basic.global_numba_func(switch)
161162

162163

164+
@numba_funcify.register(Pow)
165+
def numba_funcify_Pow(op, node, **kwargs):
166+
pow_dtype = node.inputs[1].type.dtype
167+
168+
def pow(x, y):
169+
return x**y
170+
171+
# Work-around https://github.com/numba/numba/issues/9554
172+
# fast-math casuse kernel crash
173+
patch_kwargs = {}
174+
if pow_dtype.startswith("int"):
175+
patch_kwargs["fastmath"] = False
176+
return numba_basic.numba_njit(**patch_kwargs)(pow)
177+
178+
163179
def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
164180
"""Create a Numba-compatible N-ary function from a binary function."""
165181
unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_")

0 commit comments

Comments
 (0)