Skip to content

Commit fd01e63

Browse files
authored
Merge pull request #333 from ev-br/views_vs_copies
TST: add test that wrapping preserves view/copy semantics
2 parents cddc9ef + b0eed55 commit fd01e63

File tree

5 files changed

+107
-34
lines changed

5 files changed

+107
-34
lines changed

array_api_compat/common/_aliases.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -525,27 +525,6 @@ def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]:
525525
return xp.nonzero(x, **kwargs)
526526

527527

528-
# ceil, floor, and trunc return integers for integer inputs
529-
530-
531-
def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
532-
if xp.issubdtype(x.dtype, xp.integer):
533-
return x
534-
return xp.ceil(x, **kwargs)
535-
536-
537-
def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
538-
if xp.issubdtype(x.dtype, xp.integer):
539-
return x
540-
return xp.floor(x, **kwargs)
541-
542-
543-
def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
544-
if xp.issubdtype(x.dtype, xp.integer):
545-
return x
546-
return xp.trunc(x, **kwargs)
547-
548-
549528
# linear algebra functions
550529

551530

@@ -708,9 +687,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
708687
"argsort",
709688
"sort",
710689
"nonzero",
711-
"ceil",
712-
"floor",
713-
"trunc",
714690
"matmul",
715691
"matrix_transpose",
716692
"tensordot",

array_api_compat/cupy/_aliases.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,6 @@
5353
argsort = get_xp(cp)(_aliases.argsort)
5454
sort = get_xp(cp)(_aliases.sort)
5555
nonzero = get_xp(cp)(_aliases.nonzero)
56-
ceil = get_xp(cp)(_aliases.ceil)
57-
floor = get_xp(cp)(_aliases.floor)
58-
trunc = get_xp(cp)(_aliases.trunc)
5956
matmul = get_xp(cp)(_aliases.matmul)
6057
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
6158
tensordot = get_xp(cp)(_aliases.tensordot)
@@ -117,6 +114,25 @@ def count_nonzero(
117114
return cp.expand_dims(result, axis)
118115
return result
119116

117+
# ceil, floor, and trunc return integers for integer inputs
118+
119+
def ceil(x: Array, /) -> Array:
120+
if cp.issubdtype(x.dtype, cp.integer):
121+
return x.copy()
122+
return cp.ceil(x)
123+
124+
125+
def floor(x: Array, /) -> Array:
126+
if cp.issubdtype(x.dtype, cp.integer):
127+
return x.copy()
128+
return cp.floor(x)
129+
130+
131+
def trunc(x: Array, /) -> Array:
132+
if cp.issubdtype(x.dtype, cp.integer):
133+
return x.copy()
134+
return cp.trunc(x)
135+
120136

121137
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
122138
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
@@ -145,7 +161,7 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
145161
'atan2', 'atanh', 'bitwise_left_shift',
146162
'bitwise_invert', 'bitwise_right_shift',
147163
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
148-
'take_along_axis']
164+
'ceil', 'floor', 'trunc', 'take_along_axis']
149165

150166

151167
def __dir__() -> list[str]:

array_api_compat/dask/array/_aliases.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,6 @@ def arange(
133133
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
134134
vecdot = get_xp(da)(_aliases.vecdot)
135135
nonzero = get_xp(da)(_aliases.nonzero)
136-
ceil = get_xp(np)(_aliases.ceil)
137-
floor = get_xp(np)(_aliases.floor)
138-
trunc = get_xp(np)(_aliases.trunc)
139136
matmul = get_xp(np)(_aliases.matmul)
140137
tensordot = get_xp(np)(_aliases.tensordot)
141138
sign = get_xp(np)(_aliases.sign)

array_api_compat/numpy/_aliases.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,6 @@
5555
argsort = get_xp(np)(_aliases.argsort)
5656
sort = get_xp(np)(_aliases.sort)
5757
nonzero = get_xp(np)(_aliases.nonzero)
58-
ceil = get_xp(np)(_aliases.ceil)
59-
floor = get_xp(np)(_aliases.floor)
60-
trunc = get_xp(np)(_aliases.trunc)
6158
matmul = get_xp(np)(_aliases.matmul)
6259
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
6360
tensordot = get_xp(np)(_aliases.tensordot)
@@ -129,6 +126,26 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
129126
return np.take_along_axis(x, indices, axis=axis)
130127

131128

129+
# ceil, floor, and trunc return integers for integer inputs in NumPy < 2
130+
131+
def ceil(x: Array, /) -> Array:
132+
if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
133+
return x.copy()
134+
return np.ceil(x)
135+
136+
137+
def floor(x: Array, /) -> Array:
138+
if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
139+
return x.copy()
140+
return np.floor(x)
141+
142+
143+
def trunc(x: Array, /) -> Array:
144+
if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer):
145+
return x.copy()
146+
return np.trunc(x)
147+
148+
132149
# These functions are completely new here. If the library already has them
133150
# (i.e., numpy 2.0), use the library version instead of our wrapper.
134151
if hasattr(np, "vecdot"):
@@ -156,6 +173,9 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
156173
"atan",
157174
"atan2",
158175
"atanh",
176+
"ceil",
177+
"floor",
178+
"trunc",
159179
"bitwise_left_shift",
160180
"bitwise_invert",
161181
"bitwise_right_shift",

tests/test_copies_or_views.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""
2+
A collection of tests to make sure that wrapped namespaces agree with the bare ones
3+
on whether to return a view or a copy of inputs.
4+
"""
5+
import pytest
6+
from ._helpers import import_, wrapped_libraries
7+
8+
9+
FUNC_INPUTS = [
10+
# func_name, arr_input, dtype, scalar_value
11+
('abs', [1, 2], 'int8', 3),
12+
('abs', [1, 2], 'float32', 3.),
13+
('ceil', [1, 2], 'int8', 3),
14+
('clip', [1, 2], 'int8', 3),
15+
('conj', [1, 2], 'int8', 3),
16+
('floor', [1, 2], 'int8', 3),
17+
('imag', [1j, 2j], 'complex64', 3),
18+
('positive', [1, 2], 'int8', 3),
19+
('real', [1., 2.], 'float32', 3.),
20+
('round', [1, 2], 'int8', 3),
21+
('sign', [0, 0], 'float32', 3),
22+
('trunc', [1, 2], 'int8', 3),
23+
('trunc', [1, 2], 'float32', 3),
24+
]
25+
26+
27+
def ensure_unary(func, arr):
28+
"""Make a trivial unary function from func."""
29+
if func.__name__ == 'clip':
30+
return lambda x: func(x, arr[0], arr[1])
31+
return func
32+
33+
34+
def is_view(func, a, value):
35+
"""Apply `func`, mutate the output; does the input change?"""
36+
b = func(a)
37+
b[0] = value
38+
return a[0] == value
39+
40+
41+
@pytest.mark.parametrize('xp_name', wrapped_libraries + ['array_api_strict'])
42+
@pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS])
43+
def test_view_or_copy(inputs, xp_name):
44+
bare_xp = import_(xp_name, wrapper=False)
45+
wrapped_xp = import_(xp_name, wrapper=True)
46+
47+
func_name, arr_input, dtype_str, value = inputs
48+
dtype = getattr(bare_xp, dtype_str)
49+
50+
bare_func = getattr(bare_xp, func_name)
51+
bare_func = ensure_unary(bare_func, arr_input)
52+
53+
wrapped_func = getattr(wrapped_xp, func_name)
54+
wrapped_func = ensure_unary(wrapped_func, arr_input)
55+
56+
# bare namespace: mutate the output, does the input change?
57+
a = bare_xp.asarray(arr_input, dtype=dtype)
58+
is_view_bare = is_view(bare_func, a, value)
59+
60+
# wrapped namespace: mutate the output, does the input change?
61+
a1 = wrapped_xp.asarray(arr_input, dtype=dtype)
62+
is_view_wrapped = is_view(wrapped_func, a1, value)
63+
64+
assert is_view_bare == is_view_wrapped

0 commit comments

Comments
 (0)