From c9cfc2c9193fcdf0e52a2bbdace54182780839c9 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 11:37:49 +0200 Subject: [PATCH 1/5] TST: add a test that wrapping preserves a view/copy semantics for unary functions If a bare library returns a copy, so does the wrapped library; if the bare library returns a view, so does the wrapped library. --- tests/test_copies_or_views.py | 66 +++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 tests/test_copies_or_views.py diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py new file mode 100644 index 00000000..5b9b9207 --- /dev/null +++ b/tests/test_copies_or_views.py @@ -0,0 +1,66 @@ +""" +A collection of tests to make sure that wrapped namespaces agree with the bare ones +on whether to return a view or a copy of inputs. +""" +import pytest +from ._helpers import import_ + + +LIB_NAMES = ['numpy', 'cupy', 'torch', 'dask.array', 'array_api_strict'] + +FUNC_INPUTS = [ + # func_name, arr_input, dtype, scalar_value + ('abs', [1, 2], 'int8', 3), + ('abs', [1, 2], 'float32', 3.), + ('ceil', [1, 2], 'int8', 3), + ('clip', [1, 2], 'int8', 3), + ('conj', [1, 2], 'int8', 3), + ('floor', [1, 2], 'int8', 3), + ('imag', [1j, 2j], 'complex64', 3), + ('positive', [1, 2], 'int8', 3), + ('real', [1., 2.], 'float32', 3.), + ('round', [1, 2], 'int8', 3), + ('sign', [0, 0], 'float32', 3), + ('trunc', [1, 2], 'int8', 3), + ('trunc', [1, 2], 'float32', 3), +] + + +def ensure_unary(func, arr): + """Make a trivial unary function from func.""" + if func.__name__ == 'clip': + return lambda x: func(x, arr[0], arr[1]) + return func + + +def is_view(func, a, value): + """Apply `func`, mutate the output; does the input change?""" + b = func(a) + b[0] = value + return a[0] == value + + +@pytest.mark.parametrize('xp_name', LIB_NAMES) +@pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) +def test_view_or_copy(inputs, xp_name): + bare_xp = import_(xp_name, wrapper=False) + wrapped_xp = import_(xp_name, wrapper=True) + + func_name, arr_input, dtype_str, value = inputs + dtype = getattr(bare_xp, dtype_str) + + bare_func = getattr(bare_xp, func_name) + bare_func = ensure_unary(bare_func, arr_input) + + wrapped_func = getattr(wrapped_xp, func_name) + wrapped_func = ensure_unary(wrapped_func, arr_input) + + # bare namespace: mutate the output, does the input change? + a = bare_xp.asarray(arr_input, dtype=dtype) + is_view_bare = is_view(bare_func, a, value) + + # wrapped namespace: mutate the output, does the input change? + a1 = wrapped_xp.asarray(arr_input, dtype=dtype) + is_view_wrapped = is_view(wrapped_func, a1, value) + + assert is_view_bare == is_view_wrapped From 1facc3526414926b2d123e88c16f7d517d9d2558 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 14:30:54 +0200 Subject: [PATCH 2/5] BUG: make ceil,trunc,floor always respect view/copy semantics Remove these functions from common/_aliases.py, add specific implementations for numpy < 2 and cupy. --- array_api_compat/common/_aliases.py | 24 -------------------- array_api_compat/cupy/_aliases.py | 24 ++++++++++++++++---- array_api_compat/dask/array/_aliases.py | 3 --- array_api_compat/numpy/_aliases.py | 29 ++++++++++++++++++++++--- 4 files changed, 46 insertions(+), 34 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8ea9162a..39d10860 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -524,27 +524,6 @@ def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]: return xp.nonzero(x, **kwargs) -# ceil, floor, and trunc return integers for integer inputs - - -def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.ceil(x, **kwargs) - - -def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.floor(x, **kwargs) - - -def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.trunc(x, **kwargs) - - # linear algebra functions @@ -707,9 +686,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: "argsort", "sort", "nonzero", - "ceil", - "floor", - "trunc", "matmul", "matrix_transpose", "tensordot", diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 90b48f05..e000602e 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -54,9 +54,6 @@ argsort = get_xp(cp)(_aliases.argsort) sort = get_xp(cp)(_aliases.sort) nonzero = get_xp(cp)(_aliases.nonzero) -ceil = get_xp(cp)(_aliases.ceil) -floor = get_xp(cp)(_aliases.floor) -trunc = get_xp(cp)(_aliases.trunc) matmul = get_xp(cp)(_aliases.matmul) matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) tensordot = get_xp(cp)(_aliases.tensordot) @@ -123,6 +120,25 @@ def count_nonzero( return cp.expand_dims(result, axis) return result +# ceil, floor, and trunc return integers for integer inputs + +def ceil(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.ceil(x) + + +def floor(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.floor(x) + + +def trunc(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.trunc(x) + # take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): @@ -151,6 +167,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign', - 'take_along_axis'] + 'ceil', 'floor', 'trunc', 'take_along_axis'] _all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d43881ab..0bb5d227 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -134,9 +134,6 @@ def arange( matrix_transpose = get_xp(da)(_aliases.matrix_transpose) vecdot = get_xp(da)(_aliases.vecdot) nonzero = get_xp(da)(_aliases.nonzero) -ceil = get_xp(np)(_aliases.ceil) -floor = get_xp(np)(_aliases.floor) -trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index a1aee5c0..502dfb3a 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -63,9 +63,6 @@ argsort = get_xp(np)(_aliases.argsort) sort = get_xp(np)(_aliases.sort) nonzero = get_xp(np)(_aliases.nonzero) -ceil = get_xp(np)(_aliases.ceil) -floor = get_xp(np)(_aliases.floor) -trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) @@ -145,6 +142,29 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): return np.take_along_axis(x, indices, axis=axis) +# ceil, floor, and trunc return integers for integer inputs in NumPy < 2 + +def ceil(x: Array, /) -> Array: + if np.issubdtype(x.dtype, np.integer): + if np.__version__ < '2': + return x.copy() + return np.ceil(x) + + +def floor(x: Array, /) -> Array: + if np.issubdtype(x.dtype, np.integer): + if np.__version__ < '2': + return x.copy() + return np.floor(x) + + +def trunc(x: Array, /) -> Array: + if np.issubdtype(x.dtype, np.integer): + if np.__version__ < '2': + return x.copy() + return np.trunc(x) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np, "vecdot"): @@ -173,6 +193,9 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): "atan", "atan2", "atanh", + "ceil", + "floor", + "trunc", "bitwise_left_shift", "bitwise_invert", "bitwise_right_shift", From 0ad664bdfde03ec3f21d82b1048616ae5d0fb6b7 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 4 Jun 2025 10:49:43 +0200 Subject: [PATCH 3/5] Apply suggestions from code review Co-authored-by: Guido Imperiale --- tests/test_copies_or_views.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py index 5b9b9207..24d03547 100644 --- a/tests/test_copies_or_views.py +++ b/tests/test_copies_or_views.py @@ -3,7 +3,7 @@ on whether to return a view or a copy of inputs. """ import pytest -from ._helpers import import_ +from ._helpers import import_, wrapped_libraries LIB_NAMES = ['numpy', 'cupy', 'torch', 'dask.array', 'array_api_strict'] @@ -40,7 +40,7 @@ def is_view(func, a, value): return a[0] == value -@pytest.mark.parametrize('xp_name', LIB_NAMES) +@pytest.mark.parametrize('xp_name', wrapped_libraries) @pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) def test_view_or_copy(inputs, xp_name): bare_xp = import_(xp_name, wrapper=False) From 118ae2d0428be763abf1e31b2827a4800398e901 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 4 Jun 2025 10:52:46 +0200 Subject: [PATCH 4/5] TST: test views vs copies on array-api-strict, too --- tests/test_copies_or_views.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py index 24d03547..ec8995f7 100644 --- a/tests/test_copies_or_views.py +++ b/tests/test_copies_or_views.py @@ -6,8 +6,6 @@ from ._helpers import import_, wrapped_libraries -LIB_NAMES = ['numpy', 'cupy', 'torch', 'dask.array', 'array_api_strict'] - FUNC_INPUTS = [ # func_name, arr_input, dtype, scalar_value ('abs', [1, 2], 'int8', 3), @@ -40,7 +38,7 @@ def is_view(func, a, value): return a[0] == value -@pytest.mark.parametrize('xp_name', wrapped_libraries) +@pytest.mark.parametrize('xp_name', wrapped_libraries + ['array_api_strict']) @pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) def test_view_or_copy(inputs, xp_name): bare_xp = import_(xp_name, wrapper=False) From b0eed557d6dba8c87d9693ff82360b33c1af3480 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 4 Jun 2025 13:05:08 +0200 Subject: [PATCH 5/5] Apply suggestions from code review Co-authored-by: Guido Imperiale --- array_api_compat/numpy/_aliases.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 502dfb3a..f04837de 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -145,23 +145,20 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): # ceil, floor, and trunc return integers for integer inputs in NumPy < 2 def ceil(x: Array, /) -> Array: - if np.issubdtype(x.dtype, np.integer): - if np.__version__ < '2': - return x.copy() + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() return np.ceil(x) def floor(x: Array, /) -> Array: - if np.issubdtype(x.dtype, np.integer): - if np.__version__ < '2': - return x.copy() + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() return np.floor(x) def trunc(x: Array, /) -> Array: - if np.issubdtype(x.dtype, np.integer): - if np.__version__ < '2': - return x.copy() + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() return np.trunc(x)