Skip to content

Support more cases of advanced indexing in Numba #818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytensor
import pytensor.scalar.basic as ps
from pytensor import compile
from pytensor.compile import optdb
from pytensor.graph.basic import Constant, Variable
from pytensor.graph.rewriting.basic import (
WalkingGraphRewriter,
Expand Down Expand Up @@ -1932,3 +1933,111 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node):
new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
copy_stack_trace(node.outputs, new_out)
return new_out


@node_rewriter(tracks=[AdvancedSubtensor])
def ravel_multidimensional_bool_idx(fgraph, node):
    """Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba

    x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]

    The dimensions of ``x`` spanned by the boolean index are collapsed into a
    single dimension, and the boolean index is replaced by its raveled
    counterpart, which selects exactly the same elements.

    Returns ``None`` (no rewrite) when there is any integer advanced index,
    when there is not exactly one boolean index, or when the boolean index has
    fewer than 2 dimensions (already supported by Numba).
    """
    x, *idxs = node.inputs

    if any(
        isinstance(idx.type, TensorType)
        # Tuple form catches both signed ("int8", ...) and unsigned
        # ("uint8", ...) integer dtypes; unsigned arrays are equally valid
        # advanced indices and must also disqualify this rewrite.
        and idx.type.dtype.startswith(("int", "uint"))
        for idx in idxs
    ):
        # Get out if there are any other advanced indexes
        return None

    bool_idxs = [
        (i, idx)
        for i, idx in enumerate(idxs)
        if (isinstance(idx.type, TensorType) and idx.dtype == "bool")
    ]

    if len(bool_idxs) != 1:
        # Get out if there are no or multiple boolean idxs
        return None

    [(bool_idx_pos, bool_idx)] = bool_idxs
    bool_idx_ndim = bool_idx.type.ndim
    if bool_idx.type.ndim < 2:
        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
        return None

    # NOTE(review): this assumes every entry of `idxs` before the boolean index
    # consumes exactly one dimension of `x` (i.e., no np.newaxis entries);
    # confirm such graphs cannot reach this rewrite.
    x_shape = x.shape
    # Collapse the `bool_idx_ndim` dimensions spanned by the boolean index into
    # one; the raveled boolean mask then selects the same elements of `x`.
    raveled_x = x.reshape(
        (*x_shape[:bool_idx_pos], -1, *x_shape[bool_idx_pos + bool_idx_ndim :])
    )

    raveled_bool_idx = bool_idx.ravel()
    new_idxs = list(idxs)
    new_idxs[bool_idx_pos] = raveled_bool_idx

    return [raveled_x[tuple(new_idxs)]]


@node_rewriter(tracks=[AdvancedSubtensor])
def ravel_multidimensional_int_idx(fgraph, node):
    """Convert multidimensional integer indexing into equivalent vector integer index, supported by Numba

    x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3, *x.shape[1:]))


    NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices

    x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes

    Returns ``None`` (no rewrite) when there is any boolean advanced index,
    when there is not exactly one integer index, or when the integer index has
    fewer than 2 dimensions (already supported by Numba).
    """
    x, *idxs = node.inputs

    if any(
        # "bool" is the only boolean dtype name, so an exact comparison suffices
        isinstance(idx.type, TensorType) and idx.type.dtype == "bool"
        for idx in idxs
    ):
        # Get out if there are any other advanced indexes
        return None

    int_idxs = [
        (i, idx)
        for i, idx in enumerate(idxs)
        if (
            isinstance(idx.type, TensorType)
            # Tuple form also admits unsigned integer indices ("uint8", ...),
            # which are valid advanced indices the rewrite handles identically.
            and idx.dtype.startswith(("int", "uint"))
        )
    ]

    if len(int_idxs) != 1:
        # Get out if there are no or multiple integer idxs
        return None

    [(int_idx_pos, int_idx)] = int_idxs
    if int_idx.type.ndim < 2:
        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
        return None

    # Index with the flattened integer array, then restore the index's
    # original (multidimensional) shape on the output below.
    raveled_int_idx = int_idx.ravel()
    new_idxs = list(idxs)
    new_idxs[int_idx_pos] = raveled_int_idx
    raveled_subtensor = x[tuple(new_idxs)]

    # Reshape into correct shape
    # Because we only allow one advanced indexing, the output dimension corresponding to the raveled integer indexing
    # must match the input position. If there were multiple advanced indexes, this could have been forcefully moved to the front
    raveled_shape = raveled_subtensor.shape
    unraveled_shape = (
        *raveled_shape[:int_idx_pos],
        *int_idx.shape,
        *raveled_shape[int_idx_pos + 1 :],
    )
    return [raveled_subtensor.reshape(unraveled_shape)]


# Register the Numba-friendly indexing rewrites in the standard "specialize"
# phase; the "numba" tag lets backends opt in to (or out of) them.
optdb["specialize"].register(
    ravel_multidimensional_bool_idx.__name__, ravel_multidimensional_bool_idx, "numba"
)
optdb["specialize"].register(
    ravel_multidimensional_int_idx.__name__, ravel_multidimensional_int_idx, "numba"
)
53 changes: 48 additions & 5 deletions tests/link/numba/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
inc_subtensor,
set_subtensor,
)
from tests.link.numba.test_basic import compare_numba_and_py
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode


rng = np.random.default_rng(sum(map(ord, "Numba subtensors")))
Expand Down Expand Up @@ -74,6 +74,7 @@ def test_AdvancedSubtensor1_out_of_bounds():
@pytest.mark.parametrize(
"x, indices, objmode_needed",
[
# Single vector indexing (supported natively by Numba)
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(0, [1, 2, 2, 3]),
Expand All @@ -84,25 +85,63 @@ def test_AdvancedSubtensor1_out_of_bounds():
(np.array([True, False, False])),
False,
),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
# Single multidimensional indexing (supported after specialization rewrites)
(
as_tensor(np.arange(3 * 3).reshape((3, 3))),
(np.eye(3).astype(int)),
False,
),
(
as_tensor(np.arange(3 * 3).reshape((3, 3))),
(np.eye(3).astype(bool)),
False,
),
(
as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
(np.eye(3).astype(int)),
False,
),
(
as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))),
(np.eye(3).astype(bool)),
False,
),
(
as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
(slice(2, None), np.eye(3).astype(int)),
False,
),
(
as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))),
(slice(2, None), np.eye(3).astype(bool)),
False,
),
# Multiple advanced indexing, only supported in obj mode
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
(slice(None), [1, 2], [3, 4]),
True,
),
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None), [3, 4]),
True,
),
(
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], [0, 0]),
True,
),
],
)
@pytest.mark.filterwarnings("error")
def test_AdvancedSubtensor(x, indices, objmode_needed):
"""Test NumPy's advanced indexing in more than one dimension."""
out_pt = x[indices]
x_pt = x.type()
out_pt = x_pt[indices]
assert isinstance(out_pt.owner.op, AdvancedSubtensor)
out_fg = FunctionGraph([], [out_pt])
out_fg = FunctionGraph([x_pt], [out_pt])
with (
pytest.warns(
UserWarning,
Expand All @@ -111,7 +150,11 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
if objmode_needed
else contextlib.nullcontext()
):
compare_numba_and_py(out_fg, [])
compare_numba_and_py(
out_fg,
[x.data],
numba_mode=numba_mode.including("specialize"),
)


@pytest.mark.parametrize(
Expand Down
Loading