Skip to content

Commit 7d7c5b8

Browse files
committed
BUG: add copy parameter for api.reshape function
This adds a parameter to api.reshape to specify if data should be copied. This parameter is required so that api.reshape conforms to the standard. See #23410 Original NumPy Commit: c19e84e012da65828f62f54f721337d873fd04fe
1 parent 9899b78 commit 7d7c5b8

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

array_api_strict/_manipulation_functions.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,24 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
5353

5454

5555
# Note: the optional argument is called 'shape', not 'newshape'
56-
def reshape(x: Array, /, shape: Tuple[int, ...]) -> Array:
56+
def reshape(x: Array,
57+
/,
58+
shape: Tuple[int, ...],
59+
*,
60+
copy: Optional[Bool] = None) -> Array:
5761
"""
5862
Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.
5963
6064
See its docstring for more information.
6165
"""
62-
return Array._new(np.reshape(x._array, shape))
66+
if copy is False:
67+
raise NotImplementedError("copy=False is not yet implemented")
68+
69+
data = x._array
70+
if copy:
71+
data = np.copy(data)
72+
73+
return Array._new(np.reshape(data, shape))
6374

6475

6576
def roll(
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from numpy.testing import assert_raises
2+
import numpy as np
3+
4+
from numpy. import all
5+
from numpy._creation_functions import asarray
6+
from numpy._dtypes import float64, int8
7+
from numpy._manipulation_functions import (
8+
concat,
9+
reshape,
10+
stack
11+
)
12+
13+
14+
def test_concat_errors():
15+
assert_raises(TypeError, lambda: concat((1, 1), axis=None))
16+
assert_raises(TypeError, lambda: concat([asarray([1], dtype=int8), asarray([1], dtype=float64)]))
17+
18+
19+
def test_stack_errors():
20+
assert_raises(TypeError, lambda: stack([asarray([1, 1], dtype=int8), asarray([2, 2], dtype=float64)]))
21+
22+
23+
def test_reshape_copy():
24+
a = asarray([1])
25+
b = reshape(a, (1, 1), copy=True)
26+
a[0] = 0
27+
assert all(b[0, 0] == 1)
28+
assert all(a[0] == 0)
29+
assert_raises(NotImplementedError, lambda: reshape(a, (1, 1), copy=False))
30+

0 commit comments

Comments
 (0)