From 60bd39fc7eedf3f4b5395a6a78b21f3db86e7480 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 11 Jun 2024 19:18:08 +0200 Subject: [PATCH 1/2] Move numba subtensor functionality to its own module --- pytensor/link/numba/dispatch/__init__.py | 1 + pytensor/link/numba/dispatch/basic.py | 220 -------------------- pytensor/link/numba/dispatch/subtensor.py | 228 +++++++++++++++++++++ tests/link/numba/test_basic.py | 213 -------------------- tests/link/numba/test_subtensor.py | 234 ++++++++++++++++++++++ 5 files changed, 463 insertions(+), 433 deletions(-) create mode 100644 pytensor/link/numba/dispatch/subtensor.py create mode 100644 tests/link/numba/test_subtensor.py diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index 9810e14178..6dd0e8211b 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -11,5 +11,6 @@ import pytensor.link.numba.dispatch.scan import pytensor.link.numba.dispatch.sparse import pytensor.link.numba.dispatch.slinalg +import pytensor.link.numba.dispatch.subtensor # isort: on diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 4e9830d627..a341231674 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -29,7 +29,6 @@ from pytensor.link.utils import ( compile_function_src, fgraph_to_python, - unique_name_generator, ) from pytensor.scalar.basic import ScalarType from pytensor.scalar.math import Softplus @@ -38,14 +37,6 @@ from pytensor.tensor.math import Dot from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.slinalg import Solve -from pytensor.tensor.subtensor import ( - AdvancedIncSubtensor, - AdvancedIncSubtensor1, - AdvancedSubtensor, - AdvancedSubtensor1, - IncSubtensor, - Subtensor, -) from pytensor.tensor.type import TensorType from pytensor.tensor.type_other import MakeSlice, NoneConst @@ -479,217 +470,6 @@ def numba_funcify_FunctionGraph( ) -def create_index_func(node, objmode=False): - """Create a Python function that assembles and uses an index on an array.""" - - unique_names = unique_name_generator( - ["subtensor", "incsubtensor", "z"], suffix_sep="_" - ) - - def convert_indices(indices, entry): - if indices and isinstance(entry, Type): - rval = indices.pop(0) - return unique_names(rval) - elif isinstance(entry, slice): - return ( - f"slice({convert_indices(indices, entry.start)}, " - f"{convert_indices(indices, entry.stop)}, " - f"{convert_indices(indices, entry.step)})" - ) - elif isinstance(entry, type(None)): - return "None" - else: - raise ValueError() - - set_or_inc = isinstance( - node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor - ) - index_start_idx = 1 + int(set_or_inc) - - input_names = [unique_names(v, force_unique=True) for v in node.inputs] - op_indices = list(node.inputs[index_start_idx:]) - idx_list = getattr(node.op, "idx_list", None) - - indices_creation_src = ( - tuple(convert_indices(op_indices, idx) for idx in idx_list) - if idx_list - else tuple(input_names[index_start_idx:]) - ) - - if len(indices_creation_src) == 1: - indices_creation_src = f"indices = ({indices_creation_src[0]},)" - else: - indices_creation_src = ", ".join(indices_creation_src) - indices_creation_src = f"indices = ({indices_creation_src})" - - if set_or_inc: - fn_name = "incsubtensor" - if node.op.inplace: - index_prologue = f"z = {input_names[0]}" - else: - index_prologue = f"z = np.copy({input_names[0]})" - - if 
node.inputs[1].ndim == 0: - # TODO FIXME: This is a hack to get around a weird Numba typing - # issue. See https://github.com/numba/numba/issues/6000 - y_name = f"{input_names[1]}.item()" - else: - y_name = input_names[1] - - if node.op.set_instead_of_inc: - index_body = f"z[indices] = {y_name}" - else: - index_body = f"z[indices] += {y_name}" - else: - fn_name = "subtensor" - index_prologue = "" - index_body = f"z = {input_names[0]}[indices]" - - if objmode: - output_var = node.outputs[0] - - if not set_or_inc: - # Since `z` is being "created" while in object mode, it's - # considered an "outgoing" variable and needs to be manually typed - output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'" - else: - output_sig = "" - - index_body = f""" - with objmode({output_sig}): - {index_body} - """ - - subtensor_def_src = f""" -def {fn_name}({", ".join(input_names)}): - {index_prologue} - {indices_creation_src} - {index_body} - return np.asarray(z) - """ - - return subtensor_def_src - - -@numba_funcify.register(Subtensor) -@numba_funcify.register(AdvancedSubtensor1) -def numba_funcify_Subtensor(op, node, **kwargs): - objmode = isinstance(op, AdvancedSubtensor) - if objmode: - warnings.warn( - ("Numba will use object mode to allow run " "AdvancedSubtensor."), - UserWarning, - ) - - subtensor_def_src = create_index_func(node, objmode=objmode) - - global_env = {"np": np} - if objmode: - global_env["objmode"] = numba.objmode - - subtensor_fn = compile_function_src( - subtensor_def_src, "subtensor", {**globals(), **global_env} - ) - - return numba_njit(subtensor_fn, boundscheck=True) - - -@numba_funcify.register(IncSubtensor) -def numba_funcify_IncSubtensor(op, node, **kwargs): - objmode = isinstance(op, AdvancedIncSubtensor) - if objmode: - warnings.warn( - ("Numba will use object mode to allow run " "AdvancedIncSubtensor."), - UserWarning, - ) - - incsubtensor_def_src = create_index_func(node, objmode=objmode) - - global_env = {"np": np} - if objmode: - global_env["objmode"] = numba.objmode - - incsubtensor_fn = compile_function_src( - incsubtensor_def_src, "incsubtensor", {**globals(), **global_env} - ) - - return numba_njit(incsubtensor_fn, boundscheck=True) - - -@numba_funcify.register(AdvancedIncSubtensor1) -def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): - inplace = op.inplace - set_instead_of_inc = op.set_instead_of_inc - x, vals, idxs = node.inputs - # TODO: Add explicit expand_dims in make_node so we don't need to worry about this here - broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0] - - if set_instead_of_inc: - if broadcast: - - @numba_njit(boundscheck=True) - def advancedincsubtensor1_inplace(x, val, idxs): - if val.ndim == x.ndim: - core_val = val[0] - elif val.ndim == 0: - # Workaround for https://github.com/numba/numba/issues/9573 - core_val = val.item() - else: - core_val = val - - for idx in idxs: - x[idx] = core_val - return x - - else: - - @numba_njit(boundscheck=True) - def advancedincsubtensor1_inplace(x, vals, idxs): - if not len(idxs) == len(vals): - raise ValueError("The number of indices and values must match.") - for idx, val in zip(idxs, vals): - x[idx] = val - return x - else: - if broadcast: - - @numba_njit(boundscheck=True) - def advancedincsubtensor1_inplace(x, val, idxs): - if val.ndim == x.ndim: - core_val = val[0] - elif val.ndim == 0: - # Workaround for https://github.com/numba/numba/issues/9573 - core_val = val.item() - else: - core_val = val - - for idx in idxs: - x[idx] += core_val - return x - - else: - - 
@numba_njit(boundscheck=True) - def advancedincsubtensor1_inplace(x, vals, idxs): - if not len(idxs) == len(vals): - raise ValueError("The number of indices and values must match.") - for idx, val in zip(idxs, vals): - x[idx] += val - return x - - if inplace: - return advancedincsubtensor1_inplace - - else: - - @numba_njit - def advancedincsubtensor1(x, vals, idxs): - x = x.copy() - return advancedincsubtensor1_inplace(x, vals, idxs) - - return advancedincsubtensor1 - - def deepcopyop(x): return copy(x) diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py new file mode 100644 index 0000000000..3d2f3f2901 --- /dev/null +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -0,0 +1,228 @@ +import warnings + +import numba +import numpy as np + +from pytensor.graph import Type +from pytensor.link.numba.dispatch import numba_funcify +from pytensor.link.numba.dispatch.basic import numba_njit +from pytensor.link.utils import compile_function_src, unique_name_generator +from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + AdvancedSubtensor, + AdvancedSubtensor1, + IncSubtensor, + Subtensor, +) + + +def create_index_func(node, objmode=False): + """Create a Python function that assembles and uses an index on an array.""" + + unique_names = unique_name_generator( + ["subtensor", "incsubtensor", "z"], suffix_sep="_" + ) + + def convert_indices(indices, entry): + if indices and isinstance(entry, Type): + rval = indices.pop(0) + return unique_names(rval) + elif isinstance(entry, slice): + return ( + f"slice({convert_indices(indices, entry.start)}, " + f"{convert_indices(indices, entry.stop)}, " + f"{convert_indices(indices, entry.step)})" + ) + elif isinstance(entry, type(None)): + return "None" + else: + raise ValueError() + + set_or_inc = isinstance( + node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor + ) + index_start_idx = 1 + int(set_or_inc) + + input_names = [unique_names(v, force_unique=True) for v in node.inputs] + op_indices = list(node.inputs[index_start_idx:]) + idx_list = getattr(node.op, "idx_list", None) + + indices_creation_src = ( + tuple(convert_indices(op_indices, idx) for idx in idx_list) + if idx_list + else tuple(input_names[index_start_idx:]) + ) + + if len(indices_creation_src) == 1: + indices_creation_src = f"indices = ({indices_creation_src[0]},)" + else: + indices_creation_src = ", ".join(indices_creation_src) + indices_creation_src = f"indices = ({indices_creation_src})" + + if set_or_inc: + fn_name = "incsubtensor" + if node.op.inplace: + index_prologue = f"z = {input_names[0]}" + else: + index_prologue = f"z = np.copy({input_names[0]})" + + if node.inputs[1].ndim == 0: + # TODO FIXME: This is a hack to get around a weird Numba typing + # issue. 
See https://github.com/numba/numba/issues/6000 + y_name = f"{input_names[1]}.item()" + else: + y_name = input_names[1] + + if node.op.set_instead_of_inc: + index_body = f"z[indices] = {y_name}" + else: + index_body = f"z[indices] += {y_name}" + else: + fn_name = "subtensor" + index_prologue = "" + index_body = f"z = {input_names[0]}[indices]" + + if objmode: + output_var = node.outputs[0] + + if not set_or_inc: + # Since `z` is being "created" while in object mode, it's + # considered an "outgoing" variable and needs to be manually typed + output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'" + else: + output_sig = "" + + index_body = f""" + with objmode({output_sig}): + {index_body} + """ + + subtensor_def_src = f""" +def {fn_name}({", ".join(input_names)}): + {index_prologue} + {indices_creation_src} + {index_body} + return np.asarray(z) + """ + + return subtensor_def_src + + +@numba_funcify.register(Subtensor) +@numba_funcify.register(AdvancedSubtensor1) +def numba_funcify_Subtensor(op, node, **kwargs): + objmode = isinstance(op, AdvancedSubtensor) + if objmode: + warnings.warn( + ("Numba will use object mode to allow run " "AdvancedSubtensor."), + UserWarning, + ) + + subtensor_def_src = create_index_func(node, objmode=objmode) + + global_env = {"np": np} + if objmode: + global_env["objmode"] = numba.objmode + + subtensor_fn = compile_function_src( + subtensor_def_src, "subtensor", {**globals(), **global_env} + ) + + return numba_njit(subtensor_fn, boundscheck=True) + + +@numba_funcify.register(IncSubtensor) +def numba_funcify_IncSubtensor(op, node, **kwargs): + objmode = isinstance(op, AdvancedIncSubtensor) + if objmode: + warnings.warn( + ("Numba will use object mode to allow run " "AdvancedIncSubtensor."), + UserWarning, + ) + + incsubtensor_def_src = create_index_func(node, objmode=objmode) + + global_env = {"np": np} + if objmode: + global_env["objmode"] = numba.objmode + + incsubtensor_fn = compile_function_src( + incsubtensor_def_src, "incsubtensor", {**globals(), **global_env} + ) + + return numba_njit(incsubtensor_fn, boundscheck=True) + + +@numba_funcify.register(AdvancedIncSubtensor1) +def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): + inplace = op.inplace + set_instead_of_inc = op.set_instead_of_inc + x, vals, idxs = node.inputs + # TODO: Add explicit expand_dims in make_node so we don't need to worry about this here + broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0] + + if set_instead_of_inc: + if broadcast: + + @numba_njit(boundscheck=True) + def advancedincsubtensor1_inplace(x, val, idxs): + if val.ndim == x.ndim: + core_val = val[0] + elif val.ndim == 0: + # Workaround for https://github.com/numba/numba/issues/9573 + core_val = val.item() + else: + core_val = val + + for idx in idxs: + x[idx] = core_val + return x + + else: + + @numba_njit(boundscheck=True) + def advancedincsubtensor1_inplace(x, vals, idxs): + if not len(idxs) == len(vals): + raise ValueError("The number of indices and values must match.") + for idx, val in zip(idxs, vals): + x[idx] = val + return x + else: + if broadcast: + + @numba_njit(boundscheck=True) + def advancedincsubtensor1_inplace(x, val, idxs): + if val.ndim == x.ndim: + core_val = val[0] + elif val.ndim == 0: + # Workaround for https://github.com/numba/numba/issues/9573 + core_val = val.item() + else: + core_val = val + + for idx in idxs: + x[idx] += core_val + return x + + else: + + @numba_njit(boundscheck=True) + def advancedincsubtensor1_inplace(x, vals, idxs): + if not len(idxs) 
== len(vals): + raise ValueError("The number of indices and values must match.") + for idx, val in zip(idxs, vals): + x[idx] += val + return x + + if inplace: + return advancedincsubtensor1_inplace + + else: + + @numba_njit + def advancedincsubtensor1(x, vals, idxs): + x = x.copy() + return advancedincsubtensor1_inplace(x, vals, idxs) + + return advancedincsubtensor1 diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 20ecdc3002..e9f3035504 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -33,7 +33,6 @@ from pytensor.raise_op import assert_op from pytensor.scalar.basic import ScalarOp, as_scalar from pytensor.tensor import blas -from pytensor.tensor import subtensor as pt_subtensor from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape @@ -362,218 +361,6 @@ def test_create_numba_signature(v, expected, force_scalar): assert res == expected -@pytest.mark.parametrize( - "x, indices", - [ - (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (1,)), - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - (slice(None)), - ), - (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (1, 2, 0)), - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - (slice(1, 2), 1, slice(None)), - ), - ], -) -def test_Subtensor(x, indices): - """Test NumPy's basic indexing.""" - out_pt = x[indices] - assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) - - -@pytest.mark.parametrize( - "x, indices", - [ - (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)), - (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)), - ], -) -def test_AdvancedSubtensor1(x, indices): - """Test NumPy's advanced indexing in one dimension.""" - out_pt = pt_subtensor.advanced_subtensor1(x, *indices) - assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) - - -def test_AdvancedSubtensor1_out_of_bounds(): - out_pt = pt_subtensor.advanced_subtensor1(np.arange(3), [4]) - assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1) - out_fg = FunctionGraph([], [out_pt]) - with pytest.raises(IndexError): - compare_numba_and_py(out_fg, []) - - -@pytest.mark.parametrize( - "x, indices", - [ - (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])), - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - ([1, 2], slice(None), [3, 4]), - ), - ], -) -def test_AdvancedSubtensor(x, indices): - """Test NumPy's advanced indexing in more than one dimension.""" - out_pt = x[indices] - assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) - - -@pytest.mark.parametrize( - "x, y, indices", - [ - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - pt.as_tensor(np.array(10)), - (1,), - ), - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - pt.as_tensor(rng.poisson(size=(4, 5))), - (slice(None)), - ), - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - pt.as_tensor(np.array(10)), - (1, 2, 0), - ), - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - pt.as_tensor(rng.poisson(size=(1, 5))), - (slice(1, 2), 1, slice(None)), - ), - ], -) -def test_IncSubtensor(x, y, indices): - out_pt = pt.set_subtensor(x[indices], y) - assert isinstance(out_pt.owner.op, 
pt_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) - - out_pt = pt.inc_subtensor(x[indices], y) - assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) - - x_pt = x.type() - out_pt = pt.set_subtensor(x_pt[indices], y, inplace=True) - assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) - out_fg = FunctionGraph([x_pt], [out_pt]) - compare_numba_and_py(out_fg, [x.data]) - - -@pytest.mark.parametrize( - "x, y, indices", - [ - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - pt.as_tensor(rng.poisson(size=(2, 4, 5))), - ([1, 2],), - ), - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - pt.as_tensor(rng.poisson(size=(2, 4, 5))), - ([1, 1],), - ), - # Broadcasting values - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - pt.as_tensor(rng.poisson(size=(1, 4, 5))), - ([0, 2, 0],), - ), - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - pt.as_tensor(rng.poisson(size=(5,))), - ([0, 2],), - ), - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - pt.as_tensor(rng.poisson(size=())), - ([2, 0],), - ), - ( - pt.as_tensor(np.arange(5)), - pt.as_tensor(rng.poisson(size=())), - ([2, 0],), - ), - ], -) -def test_AdvancedIncSubtensor1(x, y, indices): - out_pt = pt_subtensor.advanced_set_subtensor1(x, y, *indices) - assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) - - out_pt = pt_subtensor.advanced_inc_subtensor1(x, y, *indices) - assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) - - # With symbolic inputs - x_pt = x.type() - y_pt = y.type() - - out_pt = pt_subtensor.AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices) - assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1) - out_fg = FunctionGraph([x_pt, y_pt], [out_pt]) - compare_numba_and_py(out_fg, [x.data, y.data]) - - out_pt = pt_subtensor.AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)( - x_pt, y_pt, *indices - ) - assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1) - out_fg = FunctionGraph([x_pt, y_pt], [out_pt]) - compare_numba_and_py(out_fg, [x.data, y.data]) - - -@pytest.mark.parametrize( - "x, y, indices", - [ - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - pt.as_tensor(rng.poisson(size=(2, 5))), - ([1, 2], [2, 3]), - ), - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - pt.as_tensor(rng.poisson(size=(2, 4))), - ([1, 2], slice(None), [3, 4]), - ), - pytest.param( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - pt.as_tensor(rng.poisson(size=(2, 5))), - ([1, 1], [2, 2]), - ), - ], -) -def test_AdvancedIncSubtensor(x, y, indices): - out_pt = pt.set_subtensor(x[indices], y) - assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) - - out_pt = pt.inc_subtensor(x[indices], y) - assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) - - x_pt = x.type() - out_pt = pt.set_subtensor(x_pt[indices], y) - # Inplace isn't really implemented for `AdvancedIncSubtensor`, so we just - # hack it on here - out_pt.owner.op.inplace = True - assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) - out_fg = 
FunctionGraph([x_pt], [out_pt]) - compare_numba_and_py(out_fg, [x.data]) - - @pytest.mark.parametrize( "x, i", [ diff --git a/tests/link/numba/test_subtensor.py b/tests/link/numba/test_subtensor.py new file mode 100644 index 0000000000..87f1300bfb --- /dev/null +++ b/tests/link/numba/test_subtensor.py @@ -0,0 +1,234 @@ +import numpy as np +import pytest + +from pytensor.graph import FunctionGraph +from pytensor.tensor import as_tensor +from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + AdvancedSubtensor, + AdvancedSubtensor1, + IncSubtensor, + Subtensor, + advanced_inc_subtensor1, + advanced_set_subtensor1, + advanced_subtensor1, + inc_subtensor, + set_subtensor, +) +from tests.link.numba.test_basic import compare_numba_and_py + + +rng = np.random.default_rng(sum(map(ord, "Numba subtensors"))) + + +@pytest.mark.parametrize( + "x, indices", + [ + (as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (1,)), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + (slice(None)), + ), + (as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (1, 2, 0)), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + (slice(1, 2), 1, slice(None)), + ), + ], +) +def test_Subtensor(x, indices): + """Test NumPy's basic indexing.""" + out_pt = x[indices] + assert isinstance(out_pt.owner.op, Subtensor) + out_fg = FunctionGraph([], [out_pt]) + compare_numba_and_py(out_fg, []) + + +@pytest.mark.parametrize( + "x, indices", + [ + (as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)), + (as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)), + ], +) +def test_AdvancedSubtensor1(x, indices): + """Test NumPy's advanced indexing in one dimension.""" + out_pt = advanced_subtensor1(x, *indices) + assert isinstance(out_pt.owner.op, AdvancedSubtensor1) + out_fg = FunctionGraph([], [out_pt]) + compare_numba_and_py(out_fg, []) + + +def test_AdvancedSubtensor1_out_of_bounds(): + out_pt = advanced_subtensor1(np.arange(3), [4]) + assert isinstance(out_pt.owner.op, AdvancedSubtensor1) + out_fg = FunctionGraph([], [out_pt]) + with pytest.raises(IndexError): + compare_numba_and_py(out_fg, []) + + +@pytest.mark.parametrize( + "x, indices", + [ + (as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + ([1, 2], slice(None), [3, 4]), + ), + ], +) +def test_AdvancedSubtensor(x, indices): + """Test NumPy's advanced indexing in more than one dimension.""" + out_pt = x[indices] + assert isinstance(out_pt.owner.op, AdvancedSubtensor) + out_fg = FunctionGraph([], [out_pt]) + compare_numba_and_py(out_fg, []) + + +@pytest.mark.parametrize( + "x, y, indices", + [ + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + as_tensor(np.array(10)), + (1,), + ), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + as_tensor(rng.poisson(size=(4, 5))), + (slice(None)), + ), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + as_tensor(np.array(10)), + (1, 2, 0), + ), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + as_tensor(rng.poisson(size=(1, 5))), + (slice(1, 2), 1, slice(None)), + ), + ], +) +def test_IncSubtensor(x, y, indices): + out_pt = set_subtensor(x[indices], y) + assert isinstance(out_pt.owner.op, IncSubtensor) + out_fg = FunctionGraph([], [out_pt]) + compare_numba_and_py(out_fg, []) + + out_pt = inc_subtensor(x[indices], y) + assert isinstance(out_pt.owner.op, IncSubtensor) + out_fg = FunctionGraph([], [out_pt]) + compare_numba_and_py(out_fg, []) + + x_pt = x.type() + 
out_pt = set_subtensor(x_pt[indices], y, inplace=True) + assert isinstance(out_pt.owner.op, IncSubtensor) + out_fg = FunctionGraph([x_pt], [out_pt]) + compare_numba_and_py(out_fg, [x.data]) + + +@pytest.mark.parametrize( + "x, y, indices", + [ + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + as_tensor(rng.poisson(size=(2, 4, 5))), + ([1, 2],), + ), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + as_tensor(rng.poisson(size=(2, 4, 5))), + ([1, 1],), + ), + # Broadcasting values + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + as_tensor(rng.poisson(size=(1, 4, 5))), + ([0, 2, 0],), + ), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + as_tensor(rng.poisson(size=(5,))), + ([0, 2],), + ), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + as_tensor(rng.poisson(size=())), + ([2, 0],), + ), + ( + as_tensor(np.arange(5)), + as_tensor(rng.poisson(size=())), + ([2, 0],), + ), + ], +) +def test_AdvancedIncSubtensor1(x, y, indices): + out_pt = advanced_set_subtensor1(x, y, *indices) + assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1) + out_fg = FunctionGraph([], [out_pt]) + compare_numba_and_py(out_fg, []) + + out_pt = advanced_inc_subtensor1(x, y, *indices) + assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1) + out_fg = FunctionGraph([], [out_pt]) + compare_numba_and_py(out_fg, []) + + # With symbolic inputs + x_pt = x.type() + y_pt = y.type() + + out_pt = AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices) + assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1) + out_fg = FunctionGraph([x_pt, y_pt], [out_pt]) + compare_numba_and_py(out_fg, [x.data, y.data]) + + out_pt = AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)( + x_pt, y_pt, *indices + ) + assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1) + out_fg = FunctionGraph([x_pt, y_pt], [out_pt]) + compare_numba_and_py(out_fg, [x.data, y.data]) + + +@pytest.mark.parametrize( + "x, y, indices", + [ + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + as_tensor(rng.poisson(size=(2, 5))), + ([1, 2], [2, 3]), + ), + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + as_tensor(rng.poisson(size=(2, 4))), + ([1, 2], slice(None), [3, 4]), + ), + pytest.param( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + as_tensor(rng.poisson(size=(2, 5))), + ([1, 1], [2, 2]), + ), + ], +) +def test_AdvancedIncSubtensor(x, y, indices): + out_pt = set_subtensor(x[indices], y) + assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_pt]) + compare_numba_and_py(out_fg, []) + + out_pt = inc_subtensor(x[indices], y) + assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_pt]) + compare_numba_and_py(out_fg, []) + + x_pt = x.type() + out_pt = set_subtensor(x_pt[indices], y) + # Inplace isn't really implemented for `AdvancedIncSubtensor`, so we just + # hack it on here + out_pt.owner.op.inplace = True + assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) + out_fg = FunctionGraph([x_pt], [out_pt]) + compare_numba_and_py(out_fg, [x.data]) From b6ee8e7333b4fbcca425f9bb2fc00334359e159b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 5 Jun 2024 17:37:56 +0200 Subject: [PATCH 2/2] Do not use Numba objmode for supported AdvancedSubtensor operations Use ScalarTypes in MakeSlice for compatibility with Numba --- pytensor/link/numba/dispatch/subtensor.py | 113 +++++++---------- pytensor/tensor/subtensor.py | 19 ++- pytensor/tensor/type_other.py | 2 +- tests/link/numba/test_subtensor.py | 
145 ++++++++++++++++++++-- 4 files changed, 193 insertions(+), 86 deletions(-) diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index 3d2f3f2901..178ce0b857 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -1,12 +1,10 @@ -import warnings - -import numba import numpy as np from pytensor.graph import Type from pytensor.link.numba.dispatch import numba_funcify -from pytensor.link.numba.dispatch.basic import numba_njit +from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit from pytensor.link.utils import compile_function_src, unique_name_generator +from pytensor.tensor import TensorType from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -17,7 +15,10 @@ ) -def create_index_func(node, objmode=False): +@numba_funcify.register(Subtensor) +@numba_funcify.register(IncSubtensor) +@numba_funcify.register(AdvancedSubtensor1) +def numba_funcify_default_subtensor(op, node, **kwargs): """Create a Python function that assembles and uses an index on an array.""" unique_names = unique_name_generator( @@ -40,13 +41,13 @@ def convert_indices(indices, entry): raise ValueError() set_or_inc = isinstance( - node.op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor + op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor ) index_start_idx = 1 + int(set_or_inc) input_names = [unique_names(v, force_unique=True) for v in node.inputs] op_indices = list(node.inputs[index_start_idx:]) - idx_list = getattr(node.op, "idx_list", None) + idx_list = getattr(op, "idx_list", None) indices_creation_src = ( tuple(convert_indices(op_indices, idx) for idx in idx_list) @@ -61,8 +62,7 @@ def convert_indices(indices, entry): indices_creation_src = f"indices = ({indices_creation_src})" if set_or_inc: - fn_name = "incsubtensor" - if node.op.inplace: + if op.inplace: index_prologue = f"z = {input_names[0]}" else: index_prologue = f"z = np.copy({input_names[0]})" @@ -74,84 +74,57 @@ def convert_indices(indices, entry): else: y_name = input_names[1] - if node.op.set_instead_of_inc: + if op.set_instead_of_inc: + function_name = "setsubtensor" index_body = f"z[indices] = {y_name}" else: + function_name = "incsubtensor" index_body = f"z[indices] += {y_name}" else: - fn_name = "subtensor" + function_name = "subtensor" index_prologue = "" index_body = f"z = {input_names[0]}[indices]" - if objmode: - output_var = node.outputs[0] - - if not set_or_inc: - # Since `z` is being "created" while in object mode, it's - # considered an "outgoing" variable and needs to be manually typed - output_sig = f"z='{output_var.dtype}[{', '.join([':'] * output_var.ndim)}]'" - else: - output_sig = "" - - index_body = f""" - with objmode({output_sig}): - {index_body} - """ - subtensor_def_src = f""" -def {fn_name}({", ".join(input_names)}): +def {function_name}({", ".join(input_names)}): {index_prologue} {indices_creation_src} {index_body} return np.asarray(z) """ - return subtensor_def_src - - -@numba_funcify.register(Subtensor) -@numba_funcify.register(AdvancedSubtensor1) -def numba_funcify_Subtensor(op, node, **kwargs): - objmode = isinstance(op, AdvancedSubtensor) - if objmode: - warnings.warn( - ("Numba will use object mode to allow run " "AdvancedSubtensor."), - UserWarning, - ) - - subtensor_def_src = create_index_func(node, objmode=objmode) - - global_env = {"np": np} - if objmode: - global_env["objmode"] = numba.objmode - - subtensor_fn = compile_function_src( - 
subtensor_def_src, "subtensor", {**globals(), **global_env} + func = compile_function_src( + subtensor_def_src, + function_name=function_name, + global_env=globals() | {"np": np}, ) - - return numba_njit(subtensor_fn, boundscheck=True) - - -@numba_funcify.register(IncSubtensor) -def numba_funcify_IncSubtensor(op, node, **kwargs): - objmode = isinstance(op, AdvancedIncSubtensor) - if objmode: - warnings.warn( - ("Numba will use object mode to allow run " "AdvancedIncSubtensor."), - UserWarning, + return numba_njit(func, boundscheck=True) + + +@numba_funcify.register(AdvancedSubtensor) +@numba_funcify.register(AdvancedIncSubtensor) +def numba_funcify_AdvancedSubtensor(op, node, **kwargs): + idxs = node.inputs[1:] if isinstance(op, AdvancedSubtensor) else node.inputs[2:] + adv_idxs_dims = [ + idx.type.ndim + for idx in idxs + if (isinstance(idx.type, TensorType) and idx.type.ndim > 0) + ] + + if ( + # Numba does not support indexes with more than one dimension + # Nor multiple vector indexes + (len(adv_idxs_dims) > 1 or adv_idxs_dims[0] > 1) + # The default index implementation does not handle duplicate indices correctly + or ( + isinstance(op, AdvancedIncSubtensor) + and not op.set_instead_of_inc + and not op.ignore_duplicates ) + ): + return generate_fallback_impl(op, node, **kwargs) - incsubtensor_def_src = create_index_func(node, objmode=objmode) - - global_env = {"np": np} - if objmode: - global_env["objmode"] = numba.objmode - - incsubtensor_fn = compile_function_src( - incsubtensor_def_src, "incsubtensor", {**globals(), **global_env} - ) - - return numba_njit(incsubtensor_fn, boundscheck=True) + return numba_funcify_default_subtensor(op, node, **kwargs) @numba_funcify.register(AdvancedIncSubtensor1) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index c89049105f..95fd60a97c 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -21,7 +21,12 @@ from pytensor.printing import Printer, pprint, set_precedence from pytensor.scalar.basic import ScalarConstant from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length -from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero +from pytensor.tensor.basic import ( + ScalarFromTensor, + alloc, + get_underlying_scalar_constant_value, + nonzero, +) from pytensor.tensor.blockwise import vectorize_node_fallback from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError @@ -168,8 +173,16 @@ def as_index_literal( if isinstance(idx, Constant): return idx.data.item() if isinstance(idx, np.ndarray) else idx.data - if isinstance(getattr(idx, "type", None), SliceType): - idx = slice(*idx.owner.inputs) + if isinstance(idx, Variable): + if ( + isinstance(idx.type, ps.ScalarType) + and idx.owner + and isinstance(idx.owner.op, ScalarFromTensor) + ): + return as_index_literal(idx.owner.inputs[0]) + + if isinstance(idx.type, SliceType): + idx = slice(*idx.owner.inputs) if isinstance(idx, slice): return slice( diff --git a/pytensor/tensor/type_other.py b/pytensor/tensor/type_other.py index 5704b43859..3344a44d94 100644 --- a/pytensor/tensor/type_other.py +++ b/pytensor/tensor/type_other.py @@ -18,7 +18,7 @@ def as_int_none_variable(x): return NoneConst elif NoneConst.equals(x): return x - x = pytensor.tensor.as_tensor_variable(x, ndim=0) + x = pytensor.scalar.as_scalar(x) if x.type.dtype not in integer_dtypes: raise TypeError("index must be integers") return x diff --git 
a/tests/link/numba/test_subtensor.py b/tests/link/numba/test_subtensor.py
index 87f1300bfb..5e1784f368 100644
--- a/tests/link/numba/test_subtensor.py
+++ b/tests/link/numba/test_subtensor.py
@@ -1,6 +1,9 @@
+import contextlib
+
 import numpy as np
 import pytest
 
+import pytensor.tensor as pt
 from pytensor.graph import FunctionGraph
 from pytensor.tensor import as_tensor
 from pytensor.tensor.subtensor import (
@@ -48,8 +51,8 @@ def test_Subtensor(x, indices):
 @pytest.mark.parametrize(
     "x, indices",
     [
-        (as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)),
-        (as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)),
+        (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)),
+        (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)),
     ],
 )
 def test_AdvancedSubtensor1(x, indices):
@@ -69,21 +72,46 @@ def test_AdvancedSubtensor1_out_of_bounds():
 
 
 @pytest.mark.parametrize(
-    "x, indices",
+    "x, indices, objmode_needed",
     [
-        (as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])),
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            (0, [1, 2, 2, 3]),
+            False,
+        ),
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            (np.array([True, False, False])),
+            False,
+        ),
+        (
+            as_tensor(np.arange(3 * 3).reshape((3, 3))),
+            (np.eye(3).astype(bool)),
+            True,
+        ),
+        (pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), True),
         (
             as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
             ([1, 2], slice(None), [3, 4]),
+            True,
         ),
     ],
 )
-def test_AdvancedSubtensor(x, indices):
+@pytest.mark.filterwarnings("error")
+def test_AdvancedSubtensor(x, indices, objmode_needed):
     """Test NumPy's advanced indexing in more than one dimension."""
     out_pt = x[indices]
     assert isinstance(out_pt.owner.op, AdvancedSubtensor)
     out_fg = FunctionGraph([], [out_pt])
-    compare_numba_and_py(out_fg, [])
+    with (
+        pytest.warns(
+            UserWarning,
+            match="Numba will use object mode to run AdvancedSubtensor's perform method",
+        )
+        if objmode_needed
+        else contextlib.nullcontext()
+    ):
+        compare_numba_and_py(out_fg, [])
 
 
 @pytest.mark.parametrize(
@@ -194,35 +222,120 @@ def test_AdvancedIncSubtensor1(x, y, indices):
 
 
 @pytest.mark.parametrize(
-    "x, y, indices",
+    "x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode",
     [
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            -np.arange(3 * 5).reshape(3, 5),
+            (slice(None, None, 2), [1, 2, 3]),
+            False,
+            False,
+            False,
+        ),
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            -99,
+            (slice(None, None, 2), [1, 2, 3], -1),
+            False,
+            False,
+            False,
+        ),
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            -99,  # Broadcasted value
+            (slice(None, None, 2), [1, 2, 3]),
+            False,
+            False,
+            False,
+        ),
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            -np.arange(4 * 5).reshape(4, 5),
+            (0, [1, 2, 2, 3]),
+            True,
+            False,
+            True,
+        ),
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            [-99],  # Broadcasted value
+            (0, [1, 2, 2, 3]),
+            True,
+            False,
+            True,
+        ),
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            -np.arange(1 * 4 * 5).reshape(1, 4, 5),
+            (np.array([True, False, False])),
+            False,
+            False,
+            False,
+        ),
+        (
+            as_tensor(np.arange(3 * 3).reshape((3, 3))),
+            -np.arange(3),
+            (np.eye(3).astype(bool)),
+            False,
+            True,
+            True,
+        ),
         (
             as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
             as_tensor(rng.poisson(size=(2, 5))),
             ([1, 2], [2, 3]),
+            False,
+            True,
+            True,
        ),
        (
            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
            as_tensor(rng.poisson(size=(2, 
4))), ([1, 2], slice(None), [3, 4]), + False, + True, + True, ), pytest.param( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(rng.poisson(size=(2, 5))), ([1, 1], [2, 2]), + False, + True, + True, ), ], ) -def test_AdvancedIncSubtensor(x, y, indices): +@pytest.mark.filterwarnings("error") +def test_AdvancedIncSubtensor( + x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode +): out_pt = set_subtensor(x[indices], y) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) - out_pt = inc_subtensor(x[indices], y) + with ( + pytest.warns( + UserWarning, + match="Numba will use object mode to run AdvancedSetSubtensor's perform method", + ) + if set_requires_objmode + else contextlib.nullcontext() + ): + compare_numba_and_py(out_fg, []) + + out_pt = inc_subtensor(x[indices], y, ignore_duplicates=not duplicate_indices) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) out_fg = FunctionGraph([], [out_pt]) - compare_numba_and_py(out_fg, []) + with ( + pytest.warns( + UserWarning, + match="Numba will use object mode to run AdvancedIncSubtensor's perform method", + ) + if inc_requires_objmode + else contextlib.nullcontext() + ): + compare_numba_and_py(out_fg, []) x_pt = x.type() out_pt = set_subtensor(x_pt[indices], y) @@ -231,4 +344,12 @@ def test_AdvancedIncSubtensor(x, y, indices): out_pt.owner.op.inplace = True assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) out_fg = FunctionGraph([x_pt], [out_pt]) - compare_numba_and_py(out_fg, [x.data]) + with ( + pytest.warns( + UserWarning, + match="Numba will use object mode to run AdvancedSetSubtensor's perform method", + ) + if set_requires_objmode + else contextlib.nullcontext() + ): + compare_numba_and_py(out_fg, [x.data])
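---

A minimal usage sketch (not part of the patches above) of the behavior the new tests pin down, assuming a PyTensor build with the Numba backend available. The indexing patterns and the warning text are taken directly from test_subtensor.py: after PATCH 2/2, a single vector integer index compiles natively under mode="NUMBA", while multiple vector indices still go through the object-mode fallback with a UserWarning.

    import numpy as np

    import pytensor
    import pytensor.tensor as pt

    x = pt.tensor3("x", dtype="float64")

    # Single vector index mixed with basic indices: dispatched natively by
    # numba_funcify_AdvancedSubtensor, so no object-mode warning is expected.
    fn = pytensor.function([x], x[0, [1, 2, 2, 3]], mode="NUMBA")
    fn(np.arange(60, dtype="float64").reshape(3, 4, 5))

    # Multiple vector indices are not supported by Numba's native indexing,
    # so this warns "Numba will use object mode to run AdvancedSubtensor's
    # perform method" and runs the fallback implementation instead.
    fn2 = pytensor.function([x], x[[1, 2], :, [3, 4]], mode="NUMBA")
    fn2(np.arange(60, dtype="float64").reshape(3, 4, 5))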