Skip to content

Commit b269182

Browse files
committed
Support single multidimensional indexing in Numba via rewrites
1 parent a14cb2b commit b269182

File tree

2 files changed

+157
-6
lines changed

2 files changed

+157
-6
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytensor
88
import pytensor.scalar.basic as ps
99
from pytensor import compile
10+
from pytensor.compile import optdb
1011
from pytensor.graph.basic import Constant, Variable
1112
from pytensor.graph.rewriting.basic import (
1213
WalkingGraphRewriter,
@@ -1934,3 +1935,111 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node):
19341935
new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
19351936
copy_stack_trace(node.outputs, new_out)
19361937
return new_out
1938+
1939+
1940+
@node_rewriter(tracks=[AdvancedSubtensor])
def ravel_multidimensional_bool_idx(fgraph, node):
    """Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba

    x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]

    The rewrite only fires when the node has exactly one boolean tensor index
    with two or more dimensions and no integer (signed or unsigned) tensor
    index or ``np.newaxis`` entry. In every other case ``None`` is returned
    and the node is left untouched.
    """
    x, *idxs = node.inputs

    if any(
        (
            isinstance(idx.type, TensorType)
            # `startswith("int")` alone would let "uint8", "uint16", ... through
            and idx.type.dtype.startswith(("int", "uint"))
        )
        # np.newaxis entries (assumed to arrive as a Constant holding None)
        # shift the mapping between index position and dimension of `x`,
        # which the reshape below relies on.
        or (isinstance(idx, Constant) and idx.data is None)
        for idx in idxs
    ):
        # Get out if there are any other advanced indexes or np.newaxis
        return None

    bool_idxs = [
        (i, idx)
        for i, idx in enumerate(idxs)
        if (isinstance(idx.type, TensorType) and idx.type.dtype == "bool")
    ]

    if len(bool_idxs) != 1:
        # Get out if there are no or multiple boolean idxs
        return None

    [(bool_idx_pos, bool_idx)] = bool_idxs
    bool_idx_ndim = bool_idx.type.ndim
    if bool_idx_ndim < 2:
        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
        return None

    # Collapse the dimensions of x covered by the boolean index into a single
    # one, so that the raveled (vector) boolean index selects the same entries.
    x_shape = x.shape
    raveled_x = x.reshape(
        (*x_shape[:bool_idx_pos], -1, *x_shape[bool_idx_pos + bool_idx_ndim :])
    )

    raveled_bool_idx = bool_idx.ravel()
    new_idxs = list(idxs)
    new_idxs[bool_idx_pos] = raveled_bool_idx

    new_out = [raveled_x[tuple(new_idxs)]]
    # Keep the stack trace of the original node on its replacement
    copy_stack_trace(node.outputs, new_out)
    return new_out
1981+
1982+
1983+
@node_rewriter(tracks=[AdvancedSubtensor])
def ravel_multidimensional_int_idx(fgraph, node):
    """Convert multidimensional integer indexing into equivalent vector integer index, supported by Numba

    x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3))

    NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices

    x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes

    The rewrite only fires when the node has exactly one integer (signed or
    unsigned) tensor index with two or more dimensions and no boolean tensor
    index. In every other case ``None`` is returned and the node is left
    untouched.
    """
    x, *idxs = node.inputs

    if any(
        isinstance(idx.type, TensorType) and idx.type.dtype == "bool"
        for idx in idxs
    ):
        # Get out if there are any other advanced indexes
        return None

    int_idxs = [
        (i, idx)
        for i, idx in enumerate(idxs)
        if (
            isinstance(idx.type, TensorType)
            # `startswith("int")` alone would skip "uint8", "uint16", ...
            and idx.type.dtype.startswith(("int", "uint"))
        )
    ]

    if len(int_idxs) != 1:
        # Get out if there are no or multiple integer idxs
        return None

    [(int_idx_pos, int_idx)] = int_idxs
    if int_idx.type.ndim < 2:
        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
        return None

    raveled_int_idx = int_idx.ravel()
    new_idxs = list(idxs)
    new_idxs[int_idx_pos] = raveled_int_idx
    raveled_subtensor = x[tuple(new_idxs)]

    # Reshape into correct shape
    # Because we only allow one advanced indexing, the output dimension corresponding to the raveled integer indexing
    # must match the input position. If there were multiple advanced indexes, this could have been forcefully moved to the front
    raveled_shape = raveled_subtensor.shape
    unraveled_shape = (
        *raveled_shape[:int_idx_pos],
        *int_idx.shape,
        *raveled_shape[int_idx_pos + 1 :],
    )

    new_out = [raveled_subtensor.reshape(unraveled_shape)]
    # Keep the stack trace of the original node on its replacement
    copy_stack_trace(node.outputs, new_out)
    return new_out
2033+
2034+
2035+
# Register both index-raveling rewrites in the "specialize" database under
# their own function names, tagged "numba".
for _ravel_idx_rewrite in (
    ravel_multidimensional_bool_idx,
    ravel_multidimensional_int_idx,
):
    optdb["specialize"].register(
        _ravel_idx_rewrite.__name__,
        _ravel_idx_rewrite,
        "numba",
    )
# Do not leave the loop variable behind in the module namespace
del _ravel_idx_rewrite

tests/link/numba/test_subtensor.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
inc_subtensor,
2020
set_subtensor,
2121
)
22-
from tests.link.numba.test_basic import compare_numba_and_py
23-
22+
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
2423

2524
rng = np.random.default_rng(sum(map(ord, "Numba subtensors")))
2625

@@ -74,6 +73,7 @@ def test_AdvancedSubtensor1_out_of_bounds():
7473
@pytest.mark.parametrize(
7574
"x, indices, objmode_needed",
7675
[
76+
# Single vector indexing (supported natively by Numba)
7777
(
7878
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
7979
(0, [1, 2, 2, 3]),
@@ -84,25 +84,63 @@ def test_AdvancedSubtensor1_out_of_bounds():
8484
(np.array([True, False, False])),
8585
False,
8686
),
87+
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
88+
# Single multidimensional indexing (supported after specialization rewrites)
89+
(
90+
as_tensor(np.arange(3 * 3).reshape((3, 3))),
91+
(np.eye(3).astype(int)),
92+
False,
93+
),
8794
(
8895
as_tensor(np.arange(3 * 3).reshape((3, 3))),
8996
(np.eye(3).astype(bool)),
97+
False,
98+
),
99+
(
100+
as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
101+
(np.eye(3).astype(int)),
102+
False,
103+
),
104+
(
105+
as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
106+
(np.eye(3).astype(bool)),
107+
False,
108+
),
109+
(
110+
as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
111+
(slice(2, None), np.eye(3).astype(int)),
112+
False,
113+
),
114+
(
115+
as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
116+
(slice(2, None), np.eye(3).astype(bool)),
117+
False,
118+
),
119+
# Multiple advanced indexing, only supported in obj mode
120+
(
121+
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
122+
(slice(None), [1, 2], [3, 4]),
90123
True,
91124
),
92-
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
93125
(
94126
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
95127
([1, 2], slice(None), [3, 4]),
96128
True,
97129
),
130+
(
131+
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
132+
([[1, 2], [2, 1]], [0, 0]),
133+
True,
134+
),
98135
],
99136
)
100137
@pytest.mark.filterwarnings("error")
101138
def test_AdvancedSubtensor(x, indices, objmode_needed):
102139
"""Test NumPy's advanced indexing in more than one dimension."""
103-
out_pt = x[indices]
140+
x_pt = x.type()
141+
out_pt = x_pt[indices]
104142
assert isinstance(out_pt.owner.op, AdvancedSubtensor)
105-
out_fg = FunctionGraph([], [out_pt])
143+
out_fg = FunctionGraph([x_pt], [out_pt])
106144
with (
107145
pytest.warns(
108146
UserWarning,
@@ -111,7 +149,11 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
111149
if objmode_needed
112150
else contextlib.nullcontext()
113151
):
114-
compare_numba_and_py(out_fg, [])
152+
compare_numba_and_py(
153+
out_fg,
154+
[x.data],
155+
numba_mode=numba_mode.including("specialize"),
156+
)
115157

116158

117159
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)