Skip to content

Commit 6dc39f9

Browse files
authored
updating tests - Part1 (#2210)
This is part 1 of a series of PRs in which the tests are refactored. In this PR, `test_linalg.py`, `test_product.py`, `test_statistics.py`, `test_fft.py`, and `test_sort.py` are updated.
1 parent 4607833 commit 6dc39f9

File tree

9 files changed

+1111
-1280
lines changed

9 files changed

+1111
-1280
lines changed

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 2 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@
3737
3838
"""
3939

40-
4140
import numpy
42-
from dpctl.tensor._numpy_helper import normalize_axis_tuple
4341

4442
import dpnp
4543

@@ -48,6 +46,7 @@
4846
dpnp_dot,
4947
dpnp_kron,
5048
dpnp_matmul,
49+
dpnp_tensordot,
5150
dpnp_vecdot,
5251
)
5352

@@ -1047,65 +1046,7 @@ def tensordot(a, b, axes=2):
10471046
# TODO: use specific scalar-vector kernel
10481047
return dpnp.multiply(a, b)
10491048

1050-
try:
1051-
iter(axes)
1052-
except Exception as e: # pylint: disable=broad-exception-caught
1053-
if not isinstance(axes, int):
1054-
raise TypeError("Axes must be an integer.") from e
1055-
if axes < 0:
1056-
raise ValueError("Axes must be a non-negative integer.") from e
1057-
axes_a = tuple(range(-axes, 0))
1058-
axes_b = tuple(range(0, axes))
1059-
else:
1060-
if len(axes) != 2:
1061-
raise ValueError("Axes must consist of two sequences.")
1062-
1063-
axes_a, axes_b = axes
1064-
axes_a = (axes_a,) if dpnp.isscalar(axes_a) else axes_a
1065-
axes_b = (axes_b,) if dpnp.isscalar(axes_b) else axes_b
1066-
1067-
if len(axes_a) != len(axes_b):
1068-
raise ValueError("Axes length mismatch.")
1069-
1070-
# Make the axes non-negative
1071-
a_ndim = a.ndim
1072-
b_ndim = b.ndim
1073-
axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis_a")
1074-
axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis_b")
1075-
1076-
if a.ndim == 0 or b.ndim == 0:
1077-
# TODO: use specific scalar-vector kernel
1078-
return dpnp.multiply(a, b)
1079-
1080-
a_shape = a.shape
1081-
b_shape = b.shape
1082-
for axis_a, axis_b in zip(axes_a, axes_b):
1083-
if a_shape[axis_a] != b_shape[axis_b]:
1084-
raise ValueError(
1085-
"shape of input arrays is not similar at requested axes."
1086-
)
1087-
1088-
# Move the axes to sum over, to the end of "a"
1089-
not_in = tuple(k for k in range(a_ndim) if k not in axes_a)
1090-
newaxes_a = not_in + axes_a
1091-
n1 = int(numpy.prod([a_shape[ax] for ax in not_in]))
1092-
n2 = int(numpy.prod([a_shape[ax] for ax in axes_a]))
1093-
newshape_a = (n1, n2)
1094-
olda = [a_shape[axis] for axis in not_in]
1095-
1096-
# Move the axes to sum over, to the front of "b"
1097-
not_in = tuple(k for k in range(b_ndim) if k not in axes_b)
1098-
newaxes_b = tuple(axes_b + not_in)
1099-
n1 = int(numpy.prod([b_shape[ax] for ax in axes_b]))
1100-
n2 = int(numpy.prod([b_shape[ax] for ax in not_in]))
1101-
newshape_b = (n1, n2)
1102-
oldb = [b_shape[axis] for axis in not_in]
1103-
1104-
at = dpnp.transpose(a, newaxes_a).reshape(newshape_a)
1105-
bt = dpnp.transpose(b, newaxes_b).reshape(newshape_b)
1106-
res = dpnp.matmul(at, bt)
1107-
1108-
return res.reshape(olda + oldb)
1049+
return dpnp_tensordot(a, b, axes=axes)
11091050

11101051

11111052
def vdot(a, b):

dpnp/dpnp_iface_sorting.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,18 @@ def _wrap_sort_argsort(
6464

6565
if order is not None:
6666
raise NotImplementedError(
67-
"order keyword argument is only supported with its default value."
68-
)
69-
if kind is not None and stable is not None:
70-
raise ValueError(
71-
"`kind` and `stable` parameters can't be provided at the same time."
72-
" Use only one of them."
67+
"`order` keyword argument is only supported with its default value."
7368
)
69+
if stable is not None:
70+
if stable not in [True, False]:
71+
raise ValueError(
72+
"`stable` parameter should be None, True, or False."
73+
)
74+
if kind is not None:
75+
raise ValueError(
76+
"`kind` and `stable` parameters can't be provided at"
77+
" the same time. Use only one of them."
78+
)
7479

7580
usm_a = dpnp.get_usm_ndarray(a)
7681
if axis is None:

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,14 @@
3636
from dpnp.dpnp_array import dpnp_array
3737
from dpnp.dpnp_utils import get_usm_allocations
3838

39-
__all__ = ["dpnp_cross", "dpnp_dot", "dpnp_kron", "dpnp_matmul", "dpnp_vecdot"]
39+
__all__ = [
40+
"dpnp_cross",
41+
"dpnp_dot",
42+
"dpnp_kron",
43+
"dpnp_matmul",
44+
"dpnp_tensordot",
45+
"dpnp_vecdot",
46+
]
4047

4148

4249
def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
@@ -974,6 +981,70 @@ def dpnp_matmul(
974981
return result
975982

976983

984+
def dpnp_tensordot(a, b, axes=2):
985+
"""Tensor dot product of two arrays."""
986+
987+
try:
988+
iter(axes)
989+
except Exception as e: # pylint: disable=broad-exception-caught
990+
if not isinstance(axes, int):
991+
raise TypeError("Axes must be an integer.") from e
992+
if axes < 0:
993+
raise ValueError("Axes must be a non-negative integer.") from e
994+
axes_a = tuple(range(-axes, 0))
995+
axes_b = tuple(range(0, axes))
996+
else:
997+
if len(axes) != 2:
998+
raise ValueError("Axes must consist of two sequences.")
999+
1000+
axes_a, axes_b = axes
1001+
axes_a = (axes_a,) if dpnp.isscalar(axes_a) else axes_a
1002+
axes_b = (axes_b,) if dpnp.isscalar(axes_b) else axes_b
1003+
1004+
if len(axes_a) != len(axes_b):
1005+
raise ValueError("Axes length mismatch.")
1006+
1007+
# Make the axes non-negative
1008+
a_ndim = a.ndim
1009+
b_ndim = b.ndim
1010+
axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis_a")
1011+
axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis_b")
1012+
1013+
if a.ndim == 0 or b.ndim == 0:
1014+
# TODO: use specific scalar-vector kernel
1015+
return dpnp.multiply(a, b)
1016+
1017+
a_shape = a.shape
1018+
b_shape = b.shape
1019+
for axis_a, axis_b in zip(axes_a, axes_b):
1020+
if a_shape[axis_a] != b_shape[axis_b]:
1021+
raise ValueError(
1022+
"shape of input arrays is not similar at requested axes."
1023+
)
1024+
1025+
# Move the axes to sum over, to the end of "a"
1026+
not_in = tuple(k for k in range(a_ndim) if k not in axes_a)
1027+
newaxes_a = not_in + axes_a
1028+
n1 = int(numpy.prod([a_shape[ax] for ax in not_in]))
1029+
n2 = int(numpy.prod([a_shape[ax] for ax in axes_a]))
1030+
newshape_a = (n1, n2)
1031+
olda = [a_shape[axis] for axis in not_in]
1032+
1033+
# Move the axes to sum over, to the front of "b"
1034+
not_in = tuple(k for k in range(b_ndim) if k not in axes_b)
1035+
newaxes_b = tuple(axes_b + not_in)
1036+
n1 = int(numpy.prod([b_shape[ax] for ax in axes_b]))
1037+
n2 = int(numpy.prod([b_shape[ax] for ax in not_in]))
1038+
newshape_b = (n1, n2)
1039+
oldb = [b_shape[axis] for axis in not_in]
1040+
1041+
at = dpnp.transpose(a, newaxes_a).reshape(newshape_a)
1042+
bt = dpnp.transpose(b, newaxes_b).reshape(newshape_b)
1043+
res = dpnp.matmul(at, bt)
1044+
1045+
return res.reshape(olda + oldb)
1046+
1047+
9771048
def dpnp_vecdot(
9781049
x1,
9791050
x2,

0 commit comments

Comments
 (0)