Skip to content

Commit 543d76e

Browse files
authored
implement dpnp.vecdot and dpnp.linalg.vecdot (#2112)
* implement dpnp.vecdot and dpnp.linalg.vecdot * improve coverage * address comments
1 parent b617d6c commit 543d76e

File tree

10 files changed

+864
-144
lines changed

10 files changed

+864
-144
lines changed

doc/reference/binary.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Bit-wise operations
2-
=================
2+
===================
33

44
.. https://numpy.org/doc/stable/reference/routines.bitwise.html
55
@@ -22,7 +22,6 @@ Element-wise bit operations
2222
dpnp.bitwise_right_shift
2323
dpnp.bitwise_count
2424

25-
2625
Bit packing
2726
-----------
2827

@@ -33,7 +32,6 @@ Bit packing
3332
dpnp.packbits
3433
dpnp.unpackbits
3534

36-
3735
Output formatting
3836
-----------------
3937

doc/reference/linalg.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ Other matrix operations
8787
-----------------------
8888
.. autosummary::
8989
:toctree: generated/
90+
:nosignatures:
9091

9192
dpnp.diagonal
9293
dpnp.linalg.diagonal (Array API compatible)
@@ -96,5 +97,6 @@ Exceptions
9697
----------
9798
.. autosummary::
9899
:toctree: generated/
100+
:nosignatures:
99101

100102
dpnp.linalg.linAlgError

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 131 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
dpnp_dot,
4949
dpnp_kron,
5050
dpnp_matmul,
51+
dpnp_vecdot,
5152
)
5253

5354
__all__ = [
@@ -60,6 +61,7 @@
6061
"outer",
6162
"tensordot",
6263
"vdot",
64+
"vecdot",
6365
]
6466

6567

@@ -145,11 +147,11 @@ def dot(a, b, out=None):
145147
# TODO: use specific scalar-vector kernel
146148
return dpnp.multiply(a, b, out=out)
147149

150+
# numpy.dot does not allow casting even if it is safe
151+
# casting="no" is used in the following
148152
if a_ndim == 1 and b_ndim == 1:
149-
return dpnp_dot(a, b, out=out)
153+
return dpnp_dot(a, b, out=out, casting="no")
150154

151-
# NumPy does not allow casting even if it is safe
152-
# casting="no" is used in the following
153155
if a_ndim == 2 and b_ndim == 2:
154156
return dpnp.matmul(a, b, out=out, casting="no")
155157

@@ -728,7 +730,6 @@ def matmul(
728730
dtype=None,
729731
subok=True,
730732
signature=None,
731-
extobj=None,
732733
axes=None,
733734
axis=None,
734735
):
@@ -752,18 +753,19 @@ def matmul(
752753
Type to use in computing the matrix product. By default, the returned
753754
array will have data type that is determined by considering
754755
Promotion Type Rule and device capabilities.
756+
Default: ``None``.
755757
casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional
756758
Controls what kind of data casting may occur.
757759
Default: ``"same_kind"``.
758760
order : {"C", "F", "A", "K", None}, optional
759761
Memory layout of the newly output array, if parameter `out` is ``None``.
760762
Default: ``"K"``.
761-
axes : {list of tuples}, optional
763+
axes : {None, list of tuples}, optional
762764
A list of tuples with indices of axes the matrix product should operate
763765
on. For instance, for the signature of ``(i,j),(j,k)->(i,k)``, the base
764766
elements are 2d matrices and these are taken to be stored in the two
765767
last axes of each argument. The corresponding axes keyword would be
766-
[(-2, -1), (-2, -1), (-2, -1)].
768+
``[(-2, -1), (-2, -1), (-2, -1)]``.
767769
Default: ``None``.
768770
769771
Returns
@@ -774,8 +776,8 @@ def matmul(
774776
775777
Limitations
776778
-----------
777-
Keyword arguments `subok`, `signature`, `extobj`, and `axis` are
778-
only supported with their default value.
779+
Keyword arguments `subok`, `signature`, and `axis` are only supported with
780+
their default values.
779781
Otherwise ``NotImplementedError`` exception will be raised.
780782
781783
See Also
@@ -834,18 +836,14 @@ def matmul(
834836
835837
"""
836838

837-
if subok is False:
839+
if not subok:
838840
raise NotImplementedError(
839841
"subok keyword argument is only supported by its default value."
840842
)
841843
if signature is not None:
842844
raise NotImplementedError(
843845
"signature keyword argument is only supported by its default value."
844846
)
845-
if extobj is not None:
846-
raise NotImplementedError(
847-
"extobj keyword argument is only supported by its default value."
848-
)
849847
if axis is not None:
850848
raise NotImplementedError(
851849
"axis keyword argument is only supported by its default value."
@@ -1135,6 +1133,9 @@ def vdot(a, b):
11351133
--------
11361134
:obj:`dpnp.dot` : Returns the dot product.
11371135
:obj:`dpnp.matmul` : Returns the matrix product.
1136+
:obj:`dpnp.vecdot` : Vector dot product of two arrays.
1137+
:obj:`dpnp.linalg.vecdot` : Array API compatible version of
1138+
:obj:`dpnp.vecdot`.
11381139
11391140
Examples
11401141
--------
@@ -1178,3 +1179,120 @@ def vdot(a, b):
11781179

11791180
# dot product of flatten arrays
11801181
return dpnp_dot(dpnp.ravel(a), dpnp.ravel(b), out=None, conjugate=True)
1182+
1183+
1184+
def vecdot(
1185+
x1,
1186+
x2,
1187+
/,
1188+
out=None,
1189+
*,
1190+
casting="same_kind",
1191+
order="K",
1192+
dtype=None,
1193+
subok=True,
1194+
signature=None,
1195+
axes=None,
1196+
axis=None,
1197+
):
1198+
r"""
1199+
Computes the vector dot product.
1200+
1201+
Let :math:`\mathbf{a}` be a vector in `x1` and :math:`\mathbf{b}` be
1202+
a corresponding vector in `x2`. The dot product is defined as:
1203+
1204+
.. math::
1205+
\mathbf{a} \cdot \mathbf{b} = \sum_{i=0}^{n-1} \overline{a_i}b_i
1206+
1207+
where the sum is over the last dimension (unless `axis` is specified) and
1208+
where :math:`\overline{a_i}` denotes the complex conjugate if :math:`a_i`
1209+
is complex and the identity otherwise.
1210+
1211+
For full documentation refer to :obj:`numpy.vecdot`.
1212+
1213+
Parameters
1214+
----------
1215+
x1 : {dpnp.ndarray, usm_ndarray}
1216+
First input array.
1217+
x2 : {dpnp.ndarray, usm_ndarray}
1218+
Second input array.
1219+
out : {None, dpnp.ndarray, usm_ndarray}, optional
1220+
A location into which the result is stored. If provided, it must have
1221+
a shape that the broadcasted shape of `x1` and `x2` with the last axis
1222+
removed. If not provided or ``None``, a freshly-allocated array is
1223+
used.
1224+
Default: ``None``.
1225+
casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional
1226+
Controls what kind of data casting may occur.
1227+
Default: ``"same_kind"``.
1228+
order : {"C", "F", "A", "K", None}, optional
1229+
Memory layout of the newly output array, if parameter `out` is ``None``.
1230+
Default: ``"K"``.
1231+
dtype : {None, dtype}, optional
1232+
Type to use in computing the vector dot product. By default, the
1233+
returned array will have data type that is determined by considering
1234+
Promotion Type Rule and device capabilities.
1235+
Default: ``None``.
1236+
axes : {None, list of tuples}, optional
1237+
A list of tuples with indices of axes the matrix product should operate
1238+
on. For instance, for the signature of ``(i),(i)->()``, the base
1239+
elements are vectors and these are taken to be stored in the last axes
1240+
of each argument. The corresponding axes keyword would be
1241+
``[(-1,), (-1), ()]``.
1242+
Default: ``None``.
1243+
axis : {None, int}, optional
1244+
Axis over which to compute the dot product. This is a short-cut for
1245+
passing in axes with entries of ``(axis,)`` for each
1246+
single-core-dimension argument and ``()`` for all others. For instance,
1247+
for a signature ``(i),(i)->()``, it is equivalent to passing in
1248+
``axes=[(axis,), (axis,), ()]``.
1249+
Default: ``None``.
1250+
1251+
Returns
1252+
-------
1253+
out : dpnp.ndarray
1254+
The vector dot product of the inputs.
1255+
This is a 0-d array only when both `x1`, `x2` are 1-d vectors.
1256+
1257+
Limitations
1258+
-----------
1259+
Keyword arguments `subok`, and `signature` are only supported with their
1260+
default values. Otherwise ``NotImplementedError`` exception will be raised.
1261+
1262+
See Also
1263+
--------
1264+
:obj:`dpnp.linalg.vecdot` : Array API compatible version.
1265+
:obj:`dpnp.vdot` : Complex-conjugating dot product.
1266+
:obj:`dpnp.einsum` : Einstein summation convention.
1267+
1268+
Examples
1269+
--------
1270+
Get the projected size along a given normal for an array of vectors.
1271+
1272+
>>> import dpnp as np
1273+
>>> v = np.array([[0., 5., 0.], [0., 0., 10.], [0., 6., 8.]])
1274+
>>> n = np.array([0., 0.6, 0.8])
1275+
>>> np.vecdot(v, n)
1276+
array([ 3., 8., 10.])
1277+
1278+
"""
1279+
1280+
if not subok:
1281+
raise NotImplementedError(
1282+
"subok keyword argument is only supported by its default value."
1283+
)
1284+
if signature is not None:
1285+
raise NotImplementedError(
1286+
"signature keyword argument is only supported by its default value."
1287+
)
1288+
1289+
return dpnp_vecdot(
1290+
x1,
1291+
x2,
1292+
out=out,
1293+
casting=casting,
1294+
order=order,
1295+
dtype=dtype,
1296+
axes=axes,
1297+
axis=axis,
1298+
)

0 commit comments

Comments
 (0)