Skip to content

Commit 99de2b0

Browse files
committed
Temporary patch for numba/numba#9554
1 parent e8e103d commit 99de2b0

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

pytensor/link/numba/dispatch/scalar.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Composite,
2424
Identity,
2525
Mul,
26+
Pow,
2627
Reciprocal,
2728
ScalarOp,
2829
Second,
@@ -154,6 +155,21 @@ def numba_funcify_Switch(op, node, **kwargs):
154155
return numba_basic.global_numba_func(switch)
155156

156157

158+
@numba_funcify.register(Pow)
159+
def numba_funcify_Pow(op, node, **kwargs):
160+
pow_dtype = node.inputs[1].type.dtype
161+
162+
def pow(x, y):
163+
return x**y
164+
165+
# Work-around https://github.com/numba/numba/issues/9554
166+
# fast-math casuse kernel crash
167+
patch_kwargs = {}
168+
if pow_dtype.startswith("int"):
169+
patch_kwargs["fastmath"] = False
170+
return numba_basic.numba_njit(**patch_kwargs)(pow)
171+
172+
157173
def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
158174
"""Create a Numba-compatible N-ary function from a binary function."""
159175
unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_")

0 commit comments

Comments
 (0)