Skip to content

Commit bd7b2ca

Browse files
committed
implement dpnp.vecdot and dpnp.linalg.vecdot
1 parent 6c945e2 commit bd7b2ca

File tree

10 files changed

+771
-111
lines changed

10 files changed

+771
-111
lines changed

doc/reference/binary.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Bit-wise operations
2-
=================
2+
===================
33

44
.. https://numpy.org/doc/stable/reference/routines.bitwise.html
55
@@ -22,7 +22,6 @@ Element-wise bit operations
2222
dpnp.bitwise_right_shift
2323
dpnp.bitwise_count
2424

25-
2625
Bit packing
2726
-----------
2827

@@ -33,7 +32,6 @@ Bit packing
3332
dpnp.packbits
3433
dpnp.unpackbits
3534

36-
3735
Output formatting
3836
-----------------
3937

doc/reference/linalg.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ Norms and other numbers
6969
dpnp.trace
7070
dpnp.linalg.trace (Array API compatible)
7171

72-
7372
Solving linear equations
7473
--------------------------
7574

@@ -88,6 +87,7 @@ Other matrix operations
8887
-----------------------
8988
.. autosummary::
9089
:toctree: generated/
90+
:nosignatures:
9191

9292
dpnp.diagonal
9393
dpnp.linalg.diagonal (Array API compatible)
@@ -97,5 +97,6 @@ Exceptions
9797
----------
9898
.. autosummary::
9999
:toctree: generated/
100+
:nosignatures:
100101

101102
dpnp.linalg.linAlgError

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 126 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
dpnp_dot,
4949
dpnp_kron,
5050
dpnp_matmul,
51+
dpnp_vecdot,
5152
)
5253

5354
__all__ = [
@@ -60,6 +61,7 @@
6061
"outer",
6162
"tensordot",
6263
"vdot",
64+
"vecdot",
6365
]
6466

6567

@@ -728,7 +730,6 @@ def matmul(
728730
dtype=None,
729731
subok=True,
730732
signature=None,
731-
extobj=None,
732733
axes=None,
733734
axis=None,
734735
):
@@ -758,12 +759,12 @@ def matmul(
758759
order : {"C", "F", "A", "K", None}, optional
759760
Memory layout of the newly output array, if parameter `out` is ``None``.
760761
Default: ``"K"``.
761-
axes : {list of tuples}, optional
762+
axes : {None, list of tuples}, optional
762763
A list of tuples with indices of axes the matrix product should operate
763764
on. For instance, for the signature of ``(i,j),(j,k)->(i,k)``, the base
764765
elements are 2d matrices and these are taken to be stored in the two
765766
last axes of each argument. The corresponding axes keyword would be
766-
[(-2, -1), (-2, -1), (-2, -1)].
767+
``[(-2, -1), (-2, -1), (-2, -1)]``.
767768
Default: ``None``.
768769
769770
Returns
@@ -774,8 +775,8 @@ def matmul(
774775
775776
Limitations
776777
-----------
777-
Keyword arguments `subok`, `signature`, `extobj`, and `axis` are
778-
only supported with their default value.
778+
Keyword arguments `subok`, `signature`, and `axis` are only supported with
779+
their default values.
779780
Otherwise ``NotImplementedError`` exception will be raised.
780781
781782
See Also
@@ -834,18 +835,14 @@ def matmul(
834835
835836
"""
836837

837-
if subok is False:
838+
if not subok:
838839
raise NotImplementedError(
839840
"subok keyword argument is only supported by its default value."
840841
)
841842
if signature is not None:
842843
raise NotImplementedError(
843844
"signature keyword argument is only supported by its default value."
844845
)
845-
if extobj is not None:
846-
raise NotImplementedError(
847-
"extobj keyword argument is only supported by its default value."
848-
)
849846
if axis is not None:
850847
raise NotImplementedError(
851848
"axis keyword argument is only supported by its default value."
@@ -1135,6 +1132,9 @@ def vdot(a, b):
11351132
--------
11361133
:obj:`dpnp.dot` : Returns the dot product.
11371134
:obj:`dpnp.matmul` : Returns the matrix product.
1135+
:obj:`dpnp.vecdot` : Vector dot product of two arrays.
1136+
:obj:`dpnp.linalg.vecdot` : Array API compatible version of
1137+
:obj:`dpnp.vecdot`.
11381138
11391139
Examples
11401140
--------
@@ -1178,3 +1178,119 @@ def vdot(a, b):
11781178

11791179
# dot product of flatten arrays
11801180
return dpnp_dot(dpnp.ravel(a), dpnp.ravel(b), out=None, conjugate=True)
1181+
1182+
1183+
def vecdot(
1184+
x1,
1185+
x2,
1186+
/,
1187+
out=None,
1188+
*,
1189+
casting="same_kind",
1190+
order="K",
1191+
dtype=None,
1192+
subok=True,
1193+
signature=None,
1194+
axes=None,
1195+
axis=None,
1196+
):
1197+
r"""
1198+
Computes the vector dot product.
1199+
1200+
Let :math:`\mathbf{a}` be a vector in `x1` and :math:`\mathbf{b}` be
1201+
a corresponding vector in `x2`. The dot product is defined as:
1202+
1203+
.. math::
1204+
\mathbf{a} \cdot \mathbf{b} = \sum_{i=0}^{n-1} \overline{a_i}b_i
1205+
1206+
where the sum is over the last dimension (unless axis is specified) and
1207+
where :math:`\overline{a_i}` denotes the complex conjugate if :math:`a_i`
1208+
is complex and the identity otherwise.
1209+
1210+
For full documentation refer to :obj:`numpy.vecdot`.
1211+
1212+
Parameters
1213+
----------
1214+
x1 : {dpnp.ndarray, usm_ndarray}
1215+
First input array.
1216+
x2 : {dpnp.ndarray, usm_ndarray}
1217+
Second input array.
1218+
out : {None, dpnp.ndarray, usm_ndarray}, optional
1219+
A location into which the result is stored. If provided, it must have
1220+
a shape that the broadcasted shape of `x1` and `x2` with the last axis
1221+
removed. If not provided or ``None``, a freshly-allocated array is
1222+
used.
1223+
Default: ``None``.
1224+
dtype : {None, dtype}, optional
1225+
Type to use in computing the vector dot product. By default, the
1226+
returned array will have data type that is determined by considering
1227+
Promotion Type Rule and device capabilities.
1228+
casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional
1229+
Controls what kind of data casting may occur.
1230+
Default: ``"same_kind"``.
1231+
order : {"C", "F", "A", "K", None}, optional
1232+
Memory layout of the newly output array, if parameter `out` is ``None``.
1233+
Default: ``"K"``.
1234+
axes : {None, list of tuples}, optional
1235+
A list of tuples with indices of axes the matrix product should operate
1236+
on. For instance, for the signature of ``(i),(i)->()``, the base
1237+
elements are vectors and these are taken to be stored in the last axes
1238+
of each argument. The corresponding axes keyword would be
1239+
``[(-1,), (-1), ()]``.
1240+
Default: ``None``.
1241+
axis : {None, int}, optional
1242+
Axis over which to compute the dot product. This is a short-cut for
1243+
passing in axes with entries of ``(axis,)`` for each
1244+
single-core-dimension argument and ``()`` for all others. For instance,
1245+
for a signature ``(i),(i)->()``, it is equivalent to passing in
1246+
``axes=[(axis,), (axis,), ()]``.
1247+
Default: ``None``.
1248+
1249+
Returns
1250+
-------
1251+
out : dpnp.ndarray
1252+
The vector dot product of the inputs.
1253+
This is a 0-d array only when both `x1`, `x2` are 1-d vectors.
1254+
1255+
Limitations
1256+
-----------
1257+
Keyword arguments `subok`, and `signature` are only supported with their
1258+
default values. Otherwise ``NotImplementedError`` exception will be raised.
1259+
1260+
See Also
1261+
--------
1262+
:obj:`dpnp.linalg.vecdot` : Array API compatible version.
1263+
:obj:`dpnp.vdot` : Complex-conjugating dot product.
1264+
:obj:`dpnp.einsum` : Einstein summation convention.
1265+
1266+
Examples
1267+
--------
1268+
Get the projected size along a given normal for an array of vectors.
1269+
1270+
>>> import dpnp as np
1271+
>>> v = np.array([[0., 5., 0.], [0., 0., 10.], [0., 6., 8.]])
1272+
>>> n = np.array([0., 0.6, 0.8])
1273+
>>> np.vecdot(v, n)
1274+
array([ 3., 8., 10.])
1275+
1276+
"""
1277+
1278+
if not subok:
1279+
raise NotImplementedError(
1280+
"subok keyword argument is only supported by its default value."
1281+
)
1282+
if signature is not None:
1283+
raise NotImplementedError(
1284+
"signature keyword argument is only supported by its default value."
1285+
)
1286+
1287+
return dpnp_vecdot(
1288+
x1,
1289+
x2,
1290+
out=out,
1291+
casting=casting,
1292+
order=order,
1293+
dtype=dtype,
1294+
axes=axes,
1295+
axis=axis,
1296+
)

0 commit comments

Comments
 (0)