Skip to content

Commit b75c18f

Browse files
authored
Add numba overload for Nonzero (#1289)
* Add numba overload for Nonzero * added numba backend and testsfor Nonzero * Added numba backend for Nonzero * Modified the tests and the dispatch for efficiency
1 parent 39704d1 commit b75c18f

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from pytensor.scalar.basic import ScalarType
3434
from pytensor.scalar.math import Softplus
3535
from pytensor.sparse import SparseTensorType
36+
from pytensor.tensor.basic import Nonzero
3637
from pytensor.tensor.blas import BatchedDot
3738
from pytensor.tensor.math import Dot
3839
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@@ -657,3 +658,15 @@ def ifelse(cond, *args):
657658
return res[0]
658659

659660
return ifelse
661+
662+
663+
@numba_funcify.register(Nonzero)
664+
def numba_funcify_Nonzero(op, node, **kwargs):
665+
@numba_njit
666+
def nonzero(a):
667+
result_tuple = np.nonzero(a)
668+
if a.ndim == 1:
669+
return result_tuple[0]
670+
return list(result_tuple)
671+
672+
return nonzero

tests/link/numba/test_basic.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,6 @@ def assert_fn(x, y):
293293
)
294294
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
295295
numba_res = pytensor_numba_fn(*test_inputs_copy)
296-
297296
if isinstance(graph_outputs, tuple | list):
298297
for j, p in zip(numba_res, py_res, strict=True):
299298
assert_fn(j, p)
@@ -899,3 +898,17 @@ def test_function_overhead(mode, benchmark):
899898
assert np.sum(fn(test_x)) == 1000
900899

901900
benchmark(fn, test_x)
901+
902+
903+
@pytest.mark.parametrize(
904+
"input_data",
905+
[np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])],
906+
)
907+
def test_Nonzero(input_data):
908+
a = pt.tensor("a", shape=(None,) * input_data.ndim)
909+
910+
graph_outputs = pt.nonzero(a)
911+
912+
compare_numba_and_py(
913+
graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data]
914+
)

0 commit comments

Comments
 (0)