Skip to content

Commit a173ccc

Browse files
npolina4antonwolfy
andauthored
Added support dtype and out arguments for dpnp.concatenate and dpnp.stack functions (#1608)
* Added support dtype and out args for concatenate and out functions * address comments --------- Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
1 parent 83d28d7 commit a173ccc

File tree

3 files changed

+95
-96
lines changed

3 files changed

+95
-96
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 81 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -452,19 +452,29 @@ def concatenate(
452452
453453
For full documentation refer to :obj:`numpy.concatenate`.
454454
455+
Parameters
456+
----------
457+
arrays : {dpnp.ndarray, usm_ndarray}
458+
The arrays must have the same shape, except in the dimension corresponding
459+
to axis (the first, by default).
460+
axis : int, optional
461+
The axis along which the arrays will be joined. If axis is None, arrays are
462+
flattened before use. Default is 0.
463+
out : dpnp.ndarray, optional
464+
If provided, the destination to place the result. The shape must be correct,
465+
matching that of what concatenate would have returned if no out argument were
466+
specified.
467+
dtype : str or dtype
468+
If provided, the destination array will have this dtype. Cannot be provided
469+
together with out.
470+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
471+
Controls what kind of data casting may occur. Defaults to 'same_kind'.
472+
455473
Returns
456474
-------
457475
out : dpnp.ndarray
458476
The concatenated array.
459477
460-
Limitations
461-
-----------
462-
Each array in `arrays` is supported as either :class:`dpnp.ndarray`
463-
or :class:`dpctl.tensor.usm_ndarray`. Otherwise ``TypeError`` exception
464-
will be raised.
465-
Parameters `out` and `dtype` are supported with default value.
466-
Otherwise the function will be executed sequentially on CPU.
467-
468478
See Also
469479
--------
470480
:obj:`dpnp.array_split` : Split an array into multiple sub-arrays of equal or near-equal size.
@@ -496,25 +506,20 @@ def concatenate(
496506
497507
"""
498508

499-
if out is not None:
500-
pass
501-
elif dtype is not None:
502-
pass
503-
elif casting != "same_kind":
504-
pass
505-
else:
506-
usm_arrays = [dpnp.get_usm_ndarray(x) for x in arrays]
507-
usm_res = dpt.concat(usm_arrays, axis=axis)
508-
return dpnp_array._create_from_usm_ndarray(usm_res)
509-
510-
return call_origin(
511-
numpy.concatenate,
512-
arrays,
513-
axis=axis,
514-
out=out,
515-
dtype=dtype,
516-
casting=casting,
517-
)
509+
if dtype is not None and out is not None:
510+
raise TypeError(
511+
"concatenate() only takes `out` or `dtype` as an argument, but both were provided."
512+
)
513+
514+
usm_arrays = [dpnp.get_usm_ndarray(x) for x in arrays]
515+
usm_res = dpt.concat(usm_arrays, axis=axis)
516+
res = dpnp_array._create_from_usm_ndarray(usm_res)
517+
if dtype is not None:
518+
res = res.astype(dtype, casting=casting, copy=False)
519+
elif out is not None:
520+
dpnp.copyto(out, res, casting=casting)
521+
return out
522+
return res
518523

519524

520525
def copyto(dst, src, casting="same_kind", where=True):
@@ -868,19 +873,21 @@ def hstack(tup, *, dtype=None, casting="same_kind"):
868873
869874
For full documentation refer to :obj:`numpy.hstack`.
870875
876+
Parameters
877+
----------
878+
tup : {dpnp.ndarray, usm_ndarray}
879+
The arrays must have the same shape along all but the second axis,
880+
except 1-D arrays which can be any length.
881+
dtype : str or dtype
882+
If provided, the destination array will have this dtype.
883+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
884+
Controls what kind of data casting may occur. Defaults to 'same_kind'.
885+
871886
Returns
872887
-------
873888
out : dpnp.ndarray
874889
The stacked array which has one more dimension than the input arrays.
875890
876-
Limitations
877-
-----------
878-
Each array in `tup` is supported as either :class:`dpnp.ndarray`
879-
or :class:`dpctl.tensor.usm_ndarray`. Otherwise ``TypeError`` exception
880-
will be raised.
881-
Parameters `dtype` and `casting` are supported with default value.
882-
Otherwise the function will be executed sequentially on CPU.
883-
884891
See Also
885892
--------
886893
:obj:`dpnp.concatenate` : Join a sequence of arrays along an existing axis.
@@ -1357,26 +1364,32 @@ def squeeze(a, /, axis=None):
13571364
)
13581365

13591366

1360-
def stack(arrays, /, *, axis=0, out=None, dtype=None, **kwargs):
1367+
def stack(arrays, /, *, axis=0, out=None, dtype=None, casting="same_kind"):
13611368
"""
13621369
Join a sequence of arrays along a new axis.
13631370
13641371
For full documentation refer to :obj:`numpy.stack`.
13651372
1373+
Parameters
1374+
----------
1375+
arrays : {dpnp.ndarray, usm_ndarray}
1376+
Each array must have the same shape.
1377+
axis : int, optional
1378+
The axis in the result array along which the input arrays are stacked.
1379+
out : dpnp.ndarray, optional
1380+
If provided, the destination to place the result. The shape must be correct,
1381+
matching that of what stack would have returned if no out argument were specified.
1382+
dtype : str or dtype
1383+
If provided, the destination array will have this dtype. Cannot be provided
1384+
together with out.
1385+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
1386+
Controls what kind of data casting may occur. Defaults to 'same_kind'.
1387+
13661388
Returns
13671389
-------
13681390
out : dpnp.ndarray
13691391
The stacked array which has one more dimension than the input arrays.
13701392
1371-
Limitations
1372-
-----------
1373-
Each array in `arrays` is supported as either :class:`dpnp.ndarray`
1374-
or :class:`dpctl.tensor.usm_ndarray`. Otherwise ``TypeError`` exception
1375-
will be raised.
1376-
Parameters `out` and `dtype` are supported with default value.
1377-
Keyword argument `kwargs` is currently unsupported.
1378-
Otherwise the function will be executed sequentially on CPU.
1379-
13801393
See Also
13811394
--------
13821395
:obj:`dpnp.concatenate` : Join a sequence of arrays along an existing axis.
@@ -1409,25 +1422,20 @@ def stack(arrays, /, *, axis=0, out=None, dtype=None, **kwargs):
14091422
14101423
"""
14111424

1412-
if kwargs:
1413-
pass
1425+
if dtype is not None and out is not None:
1426+
raise TypeError(
1427+
"stack() only takes `out` or `dtype` as an argument, but both were provided."
1428+
)
1429+
1430+
usm_arrays = [dpnp.get_usm_ndarray(x) for x in arrays]
1431+
usm_res = dpt.stack(usm_arrays, axis=axis)
1432+
res = dpnp_array._create_from_usm_ndarray(usm_res)
1433+
if dtype is not None:
1434+
res = res.astype(dtype, casting=casting, copy=False)
14141435
elif out is not None:
1415-
pass
1416-
elif dtype is not None:
1417-
pass
1418-
else:
1419-
usm_arrays = [dpnp.get_usm_ndarray(x) for x in arrays]
1420-
usm_res = dpt.stack(usm_arrays, axis=axis)
1421-
return dpnp_array._create_from_usm_ndarray(usm_res)
1422-
1423-
return call_origin(
1424-
numpy.stack,
1425-
arrays,
1426-
axis=axis,
1427-
out=out,
1428-
dtype=dtype,
1429-
**kwargs,
1430-
)
1436+
dpnp.copyto(out, res, casting=casting)
1437+
return out
1438+
return res
14311439

14321440

14331441
def swapaxes(a, axis1, axis2):
@@ -1649,19 +1657,21 @@ def vstack(tup, *, dtype=None, casting="same_kind"):
16491657
16501658
For full documentation refer to :obj:`numpy.vstack`.
16511659
1660+
Parameters
1661+
----------
1662+
tup : {dpnp.ndarray, usm_ndarray}
1663+
The arrays must have the same shape along all but the first axis.
1664+
1-D arrays must have the same length.
1665+
dtype : str or dtype
1666+
If provided, the destination array will have this dtype.
1667+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
1668+
Controls what kind of data casting may occur. Defaults to 'same_kind'.
1669+
16521670
Returns
16531671
-------
16541672
out : dpnp.ndarray
16551673
The array formed by stacking the given arrays, will be at least 2-D.
16561674
1657-
Limitations
1658-
-----------
1659-
Each array in `tup` is supported as either :class:`dpnp.ndarray`
1660-
or :class:`dpctl.tensor.usm_ndarray`. Otherwise ``TypeError`` exception
1661-
will be raised.
1662-
Parameters `dtype` and `casting` are supported with default value.
1663-
Otherwise the function will be executed sequentially on CPU.
1664-
16651675
See Also
16661676
--------
16671677
:obj:`dpnp.concatenate` : Join a sequence of arrays along an existing axis.

tests/test_arraymanipulation.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,6 @@ def test_concatenate_3d(self, dtype):
305305
dp_res = dpnp.concatenate((dp_a0.T, dp_a1.T, dp_a2.T), axis=0)
306306
assert_array_equal(dp_res.asnumpy(), np_res)
307307

308-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
309308
@pytest.mark.parametrize(
310309
"dtype", get_all_dtypes(no_bool=True, no_none=True)
311310
)
@@ -329,7 +328,6 @@ def test_concatenate_out(self, dtype):
329328
assert_array_equal(dp_out.asnumpy(), np_out)
330329
assert_array_equal(dp_res.asnumpy(), np_res)
331330

332-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
333331
@pytest.mark.parametrize(
334332
"dtype", get_all_dtypes(no_bool=True, no_none=True)
335333
)
@@ -487,7 +485,6 @@ def test_empty_arrays_input(self, dtype):
487485
dp_res = dpnp.stack(dp_arrays, axis=1)
488486
assert_array_equal(dp_res.asnumpy(), np_res)
489487

490-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
491488
@pytest.mark.parametrize("dtype", get_all_dtypes())
492489
def test_out(self, dtype):
493490
np_a = numpy.array([1, 2, 3], dtype=dtype)
@@ -536,7 +533,6 @@ def test_generator_input(self):
536533
with pytest.raises(TypeError):
537534
dpnp.stack((x for x in range(3)))
538535

539-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
540536
@pytest.mark.usefixtures("suppress_complex_warning")
541537
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())
542538
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
@@ -552,7 +548,6 @@ def test_casting_dtype(self, arr_dtype, dtype):
552548
dp_res = dpnp.stack((dp_a, dp_b), axis=1, casting="unsafe", dtype=dtype)
553549
assert_array_equal(dp_res.asnumpy(), np_res)
554550

555-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
556551
@pytest.mark.parametrize("arr_dtype", get_float_complex_dtypes())
557552
@pytest.mark.parametrize("dtype", [dpnp.bool, dpnp.int32, dpnp.int64])
558553
def test_invalid_casting_dtype(self, arr_dtype, dtype):
@@ -939,3 +934,17 @@ def test_can_cast():
939934
assert dpnp.can_cast(X, "float32") == numpy.can_cast(X_np, "float32")
940935
assert dpnp.can_cast(X, dpnp.int32) == numpy.can_cast(X_np, numpy.int32)
941936
assert dpnp.can_cast(X, dpnp.int64) == numpy.can_cast(X_np, numpy.int64)
937+
938+
939+
def test_concatenate_out_dtype():
940+
x = dpnp.ones((5, 5))
941+
out = dpnp.empty_like(x)
942+
with pytest.raises(TypeError):
943+
dpnp.concatenate([x], out=out, dtype="i4")
944+
945+
946+
def test_stack_out_dtype():
947+
x = dpnp.ones((5, 5))
948+
out = dpnp.empty_like(x)
949+
with pytest.raises(TypeError):
950+
dpnp.stack([x], out=out, dtype="i4")

0 commit comments

Comments
 (0)