Skip to content

updating tests - Part1 #2210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 2 additions & 61 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@

"""


import numpy
from dpctl.tensor._numpy_helper import normalize_axis_tuple

import dpnp

Expand All @@ -48,6 +46,7 @@
dpnp_dot,
dpnp_kron,
dpnp_matmul,
dpnp_tensordot,
dpnp_vecdot,
)

Expand Down Expand Up @@ -1047,65 +1046,7 @@ def tensordot(a, b, axes=2):
# TODO: use specific scalar-vector kernel
return dpnp.multiply(a, b)

try:
iter(axes)
except Exception as e: # pylint: disable=broad-exception-caught
if not isinstance(axes, int):
raise TypeError("Axes must be an integer.") from e
if axes < 0:
raise ValueError("Axes must be a non-negative integer.") from e
axes_a = tuple(range(-axes, 0))
axes_b = tuple(range(0, axes))
else:
if len(axes) != 2:
raise ValueError("Axes must consist of two sequences.")

axes_a, axes_b = axes
axes_a = (axes_a,) if dpnp.isscalar(axes_a) else axes_a
axes_b = (axes_b,) if dpnp.isscalar(axes_b) else axes_b

if len(axes_a) != len(axes_b):
raise ValueError("Axes length mismatch.")

# Make the axes non-negative
a_ndim = a.ndim
b_ndim = b.ndim
axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis_a")
axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis_b")

if a.ndim == 0 or b.ndim == 0:
# TODO: use specific scalar-vector kernel
return dpnp.multiply(a, b)

a_shape = a.shape
b_shape = b.shape
for axis_a, axis_b in zip(axes_a, axes_b):
if a_shape[axis_a] != b_shape[axis_b]:
raise ValueError(
"shape of input arrays is not similar at requested axes."
)

# Move the axes to sum over, to the end of "a"
not_in = tuple(k for k in range(a_ndim) if k not in axes_a)
newaxes_a = not_in + axes_a
n1 = int(numpy.prod([a_shape[ax] for ax in not_in]))
n2 = int(numpy.prod([a_shape[ax] for ax in axes_a]))
newshape_a = (n1, n2)
olda = [a_shape[axis] for axis in not_in]

# Move the axes to sum over, to the front of "b"
not_in = tuple(k for k in range(b_ndim) if k not in axes_b)
newaxes_b = tuple(axes_b + not_in)
n1 = int(numpy.prod([b_shape[ax] for ax in axes_b]))
n2 = int(numpy.prod([b_shape[ax] for ax in not_in]))
newshape_b = (n1, n2)
oldb = [b_shape[axis] for axis in not_in]

at = dpnp.transpose(a, newaxes_a).reshape(newshape_a)
bt = dpnp.transpose(b, newaxes_b).reshape(newshape_b)
res = dpnp.matmul(at, bt)

return res.reshape(olda + oldb)
return dpnp_tensordot(a, b, axes=axes)


def vdot(a, b):
Expand Down
17 changes: 11 additions & 6 deletions dpnp/dpnp_iface_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,18 @@ def _wrap_sort_argsort(

if order is not None:
raise NotImplementedError(
"order keyword argument is only supported with its default value."
)
if kind is not None and stable is not None:
raise ValueError(
"`kind` and `stable` parameters can't be provided at the same time."
" Use only one of them."
"`order` keyword argument is only supported with its default value."
)
if stable is not None:
if stable not in [True, False]:
raise ValueError(
"`stable` parameter should be None, True, or False."
)
if kind is not None:
raise ValueError(
"`kind` and `stable` parameters can't be provided at"
" the same time. Use only one of them."
)

usm_a = dpnp.get_usm_ndarray(a)
if axis is None:
Expand Down
73 changes: 72 additions & 1 deletion dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,14 @@
from dpnp.dpnp_array import dpnp_array
from dpnp.dpnp_utils import get_usm_allocations

__all__ = ["dpnp_cross", "dpnp_dot", "dpnp_kron", "dpnp_matmul", "dpnp_vecdot"]
__all__ = [
"dpnp_cross",
"dpnp_dot",
"dpnp_kron",
"dpnp_matmul",
"dpnp_tensordot",
"dpnp_vecdot",
]


def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
Expand Down Expand Up @@ -974,6 +981,70 @@ def dpnp_matmul(
return result


def dpnp_tensordot(a, b, axes=2):
"""Tensor dot product of two arrays."""

try:
iter(axes)
except Exception as e: # pylint: disable=broad-exception-caught
if not isinstance(axes, int):
raise TypeError("Axes must be an integer.") from e
if axes < 0:
raise ValueError("Axes must be a non-negative integer.") from e
axes_a = tuple(range(-axes, 0))
axes_b = tuple(range(0, axes))
else:
if len(axes) != 2:
raise ValueError("Axes must consist of two sequences.")

axes_a, axes_b = axes
axes_a = (axes_a,) if dpnp.isscalar(axes_a) else axes_a
axes_b = (axes_b,) if dpnp.isscalar(axes_b) else axes_b

if len(axes_a) != len(axes_b):
raise ValueError("Axes length mismatch.")

# Make the axes non-negative
a_ndim = a.ndim
b_ndim = b.ndim
axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis_a")
axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis_b")

if a.ndim == 0 or b.ndim == 0:
# TODO: use specific scalar-vector kernel
return dpnp.multiply(a, b)

a_shape = a.shape
b_shape = b.shape
for axis_a, axis_b in zip(axes_a, axes_b):
if a_shape[axis_a] != b_shape[axis_b]:
raise ValueError(
"shape of input arrays is not similar at requested axes."
)

# Move the axes to sum over, to the end of "a"
not_in = tuple(k for k in range(a_ndim) if k not in axes_a)
newaxes_a = not_in + axes_a
n1 = int(numpy.prod([a_shape[ax] for ax in not_in]))
n2 = int(numpy.prod([a_shape[ax] for ax in axes_a]))
newshape_a = (n1, n2)
olda = [a_shape[axis] for axis in not_in]

# Move the axes to sum over, to the front of "b"
not_in = tuple(k for k in range(b_ndim) if k not in axes_b)
newaxes_b = tuple(axes_b + not_in)
n1 = int(numpy.prod([b_shape[ax] for ax in axes_b]))
n2 = int(numpy.prod([b_shape[ax] for ax in not_in]))
newshape_b = (n1, n2)
oldb = [b_shape[axis] for axis in not_in]

at = dpnp.transpose(a, newaxes_a).reshape(newshape_a)
bt = dpnp.transpose(b, newaxes_b).reshape(newshape_b)
res = dpnp.matmul(at, bt)

return res.reshape(olda + oldb)


def dpnp_vecdot(
x1,
x2,
Expand Down
Loading
Loading