Skip to content

Commit a8bcdaf

Browse files
vtavanaantonwolfy
andauthored
impelement dpnp.norm (#1746)
* impelement dpnp.norm * address comments * add float ty description * improve test coverage * make axis for test_norm_ND and ND_complex similar * mute some tests for on windows * unmute test on windows --------- Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
1 parent 20264e8 commit a8bcdaf

File tree

12 files changed

+816
-338
lines changed

12 files changed

+816
-338
lines changed

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,15 @@ PYBIND11_MODULE(_blas_impl, m)
5252

5353
{
5454
m.def("_dot", &blas_ext::dot,
55-
"Call `dot` from OneMKL LAPACK library to return "
55+
"Call `dot` from OneMKL BLAS library to return "
5656
"the dot product of two real-valued vectors.",
5757
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
5858
py::arg("result"), py::arg("depends") = py::list());
5959
}
6060

6161
{
6262
m.def("_dotc", &blas_ext::dotc,
63-
"Call `dotc` from OneMKL LAPACK library to return "
63+
"Call `dotc` from OneMKL BLAS library to return "
6464
"the dot product of two complex vectors, "
6565
"conjugating the first vector.",
6666
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
@@ -69,23 +69,23 @@ PYBIND11_MODULE(_blas_impl, m)
6969

7070
{
7171
m.def("_dotu", &blas_ext::dotu,
72-
"Call `dotu` from OneMKL LAPACK library to return "
72+
"Call `dotu` from OneMKL BLAS library to return "
7373
"the dot product of two complex vectors.",
7474
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
7575
py::arg("result"), py::arg("depends") = py::list());
7676
}
7777

7878
{
7979
m.def("_gemm", &blas_ext::gemm,
80-
"Call `gemm` from OneMKL LAPACK library to return "
80+
"Call `gemm` from OneMKL BLAS library to return "
8181
"the matrix-matrix product with 2-D matrices.",
8282
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
8383
py::arg("result"), py::arg("depends") = py::list());
8484
}
8585

8686
{
8787
m.def("_gemm_batch", &blas_ext::gemm_batch,
88-
"Call `gemm_batch` from OneMKL LAPACK library to return "
88+
"Call `gemm_batch` from OneMKL BLAS library to return "
8989
"the matrix-matrix product for a batch of 2-D matrices.",
9090
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
9191
py::arg("result"), py::arg("batch_size"), py::arg("stridea"),

dpnp/linalg/dpnp_algo_linalg.pyx

Lines changed: 0 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ __all__ = [
4848
"dpnp_cond",
4949
"dpnp_eig",
5050
"dpnp_eigvals",
51-
"dpnp_norm",
5251
]
5352

5453

@@ -171,108 +170,3 @@ cpdef utils.dpnp_descriptor dpnp_eigvals(utils.dpnp_descriptor input):
171170
c_dpctl.DPCTLEvent_Delete(event_ref)
172171

173172
return res_val
174-
175-
176-
cpdef object dpnp_norm(object input, ord=None, axis=None):
177-
cdef long size_input = input.size
178-
cdef shape_type_c shape_input = input.shape
179-
180-
dev = input.get_array().sycl_device
181-
if input.dtype == dpnp.float32 or not dev.has_aspect_fp64:
182-
res_type = dpnp.float32
183-
else:
184-
res_type = dpnp.float64
185-
186-
if size_input == 0:
187-
return dpnp.array([dpnp.nan], dtype=res_type)
188-
189-
if isinstance(axis, int):
190-
axis_ = tuple([axis])
191-
else:
192-
axis_ = axis
193-
194-
ndim = input.ndim
195-
if axis is None:
196-
if ((ord is None) or
197-
(ord in ('f', 'fro') and ndim == 2) or
198-
(ord == 2 and ndim == 1)):
199-
200-
# TODO: change order='K' when support is implemented
201-
input = dpnp.ravel(input, order='C')
202-
sqnorm = dpnp.dot(input, input)
203-
ret = dpnp.sqrt([sqnorm], dtype=res_type)
204-
return dpnp.array(ret.reshape(1, *ret.shape), dtype=res_type)
205-
206-
len_axis = 1 if axis is None else len(axis_)
207-
if len_axis == 1:
208-
if ord == dpnp.inf:
209-
return dpnp.array([dpnp.abs(input).max(axis=axis)])
210-
elif ord == -dpnp.inf:
211-
return dpnp.array([dpnp.abs(input).min(axis=axis)])
212-
elif ord == 0:
213-
return input.dtype.type(dpnp.count_nonzero(input, axis=axis))
214-
elif ord is None or ord == 2:
215-
s = input * input
216-
return dpnp.sqrt(dpnp.sum(s, axis=axis), dtype=res_type)
217-
elif isinstance(ord, str):
218-
raise ValueError(f"Invalid norm order '{ord}' for vectors")
219-
else:
220-
absx = dpnp.abs(input)
221-
absx_size = absx.size
222-
absx_power = utils_py.create_output_descriptor_py((absx_size,), absx.dtype, None).get_pyobj()
223-
224-
absx_flatiter = absx.flat
225-
226-
for i in range(absx_size):
227-
absx_elem = absx_flatiter[i]
228-
absx_power[i] = absx_elem ** ord
229-
absx_ = dpnp.reshape(absx_power, absx.shape)
230-
ret = dpnp.sum(absx_, axis=axis)
231-
ret_size = ret.size
232-
ret_power = utils_py.create_output_descriptor_py((ret_size,), None, None).get_pyobj()
233-
234-
ret_flatiter = ret.flat
235-
236-
for i in range(ret_size):
237-
ret_elem = ret_flatiter[i]
238-
ret_power[i] = ret_elem ** (1 / ord)
239-
ret_ = dpnp.reshape(ret_power, ret.shape)
240-
return ret_
241-
elif len_axis == 2:
242-
row_axis, col_axis = axis_
243-
if row_axis == col_axis:
244-
raise ValueError('Duplicate axes given.')
245-
# if ord == 2:
246-
# ret = _multi_svd_norm(input, row_axis, col_axis, amax)
247-
# elif ord == -2:
248-
# ret = _multi_svd_norm(input, row_axis, col_axis, amin)
249-
elif ord == 1:
250-
if col_axis > row_axis:
251-
col_axis -= 1
252-
dpnp_sum_val = dpnp.sum(dpnp.abs(input), axis=row_axis)
253-
ret = dpnp_sum_val.min(axis=col_axis)
254-
elif ord == dpnp.inf:
255-
if row_axis > col_axis:
256-
row_axis -= 1
257-
dpnp_sum_val = dpnp.sum(dpnp.abs(input), axis=col_axis)
258-
ret = dpnp_sum_val.max(axis=row_axis)
259-
elif ord == -1:
260-
if col_axis > row_axis:
261-
col_axis -= 1
262-
dpnp_sum_val = dpnp.sum(dpnp.abs(input), axis=row_axis)
263-
ret = dpnp_sum_val.min(axis=col_axis)
264-
elif ord == -dpnp.inf:
265-
if row_axis > col_axis:
266-
row_axis -= 1
267-
dpnp_sum_val = dpnp.sum(dpnp.abs(input), axis=col_axis)
268-
ret = dpnp_sum_val.min(axis=row_axis)
269-
elif ord in [None, 'fro', 'f']:
270-
ret = dpnp.sqrt(dpnp.sum(input * input, axis=axis))
271-
# elif ord == 'nuc':
272-
# ret = _multi_svd_norm(input, row_axis, col_axis, sum)
273-
else:
274-
raise ValueError("Invalid norm order for matrices.")
275-
276-
return ret
277-
else:
278-
raise ValueError("Improper number of dimensions to norm.")

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 88 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
dpnp_matrix_power,
5555
dpnp_matrix_rank,
5656
dpnp_multi_dot,
57+
dpnp_norm,
5758
dpnp_pinv,
5859
dpnp_qr,
5960
dpnp_slogdet,
@@ -491,7 +492,7 @@ def multi_dot(arrays, *, out=None):
491492
"""
492493
Compute the dot product of two or more arrays in a single function call.
493494
494-
For full documentation refer to :obj:`numpy.multi_dot`.
495+
For full documentation refer to :obj:`numpy.linalg.multi_dot`.
495496
496497
Parameters
497498
----------
@@ -602,60 +603,109 @@ def pinv(a, rcond=1e-15, hermitian=False):
602603
return dpnp_pinv(a, rcond=rcond, hermitian=hermitian)
603604

604605

605-
def norm(x1, ord=None, axis=None, keepdims=False):
606+
def norm(x, ord=None, axis=None, keepdims=False):
606607
"""
607608
Matrix or vector norm.
608609
609-
This function is able to return one of eight different matrix norms,
610-
or one of an infinite number of vector norms (described below), depending
611-
on the value of the ``ord`` parameter.
610+
For full documentation refer to :obj:`numpy.linalg.norm`.
612611
613612
Parameters
614613
----------
615-
input : array_like
616-
Input array. If `axis` is None, `x` must be 1-D or 2-D, unless `ord`
617-
is None. If both `axis` and `ord` are None, the 2-norm of
618-
``x.ravel`` will be returned.
619-
ord : optional
620-
Order of the norm (see table under ``Notes``). inf means numpy's
621-
`inf` object. The default is None.
622-
axis : optional.
614+
x : {dpnp.ndarray, usm_ndarray}
615+
Input array. If `axis` is ``None``, `x` must be 1-D or 2-D, unless
616+
`ord` is ``None``. If both `axis` and `ord` are ``None``, the 2-norm
617+
of ``x.ravel`` will be returned.
618+
ord : {int, float, inf, -inf, "fro", "nuc"}, optional
619+
Norm type. inf means dpnp's `inf` object. The default is ``None``.
620+
axis : {None, int, 2-tuple of ints}, optional
623621
If `axis` is an integer, it specifies the axis of `x` along which to
624622
compute the vector norms. If `axis` is a 2-tuple, it specifies the
625623
axes that hold 2-D matrices, and the matrix norms of these matrices
626-
are computed. If `axis` is None then either a vector norm (when `x`
627-
is 1-D) or a matrix norm (when `x` is 2-D) is returned. The default
628-
is None.
624+
are computed. If `axis` is ``None`` then either a vector norm (when
625+
`x` is 1-D) or a matrix norm (when `x` is 2-D) is returned.
626+
The default is ``None``.
629627
keepdims : bool, optional
630-
If this is set to True, the axes which are normed over are left in the
631-
result as dimensions with size one. With this option the result will
632-
broadcast correctly against the original `x`.
628+
If this is set to ``True``, the axes which are normed over are left in
629+
the result as dimensions with size one. With this option the result
630+
will broadcast correctly against the original `x`.
633631
634632
Returns
635633
-------
636-
n : float or ndarray
634+
out : dpnp.ndarray
637635
Norm of the matrix or vector(s).
638-
"""
639636
640-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
641-
if x1_desc:
642-
if (
643-
not isinstance(axis, int)
644-
and not isinstance(axis, tuple)
645-
and axis is not None
646-
):
647-
pass
648-
elif keepdims is not False:
649-
pass
650-
elif ord not in [None, 0, 3, "fro", "f"]:
651-
pass
652-
else:
653-
result_obj = dpnp_norm(x1, ord=ord, axis=axis)
654-
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
637+
Examples
638+
--------
639+
>>> import dpnp as np
640+
>>> a = np.arange(9) - 4
641+
>>> a
642+
array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
643+
>>> b = a.reshape((3, 3))
644+
>>> b
645+
array([[-4, -3, -2],
646+
[-1, 0, 1],
647+
[ 2, 3, 4]])
648+
649+
>>> np.linalg.norm(a)
650+
array(7.745966692414834)
651+
>>> np.linalg.norm(b)
652+
array(7.745966692414834)
653+
>>> np.linalg.norm(b, 'fro')
654+
array(7.745966692414834)
655+
>>> np.linalg.norm(a, np.inf)
656+
array(4.)
657+
>>> np.linalg.norm(b, np.inf)
658+
array(9.)
659+
>>> np.linalg.norm(a, -np.inf)
660+
array(0.)
661+
>>> np.linalg.norm(b, -np.inf)
662+
array(2.)
663+
664+
>>> np.linalg.norm(a, 1)
665+
array(20.)
666+
>>> np.linalg.norm(b, 1)
667+
array(7.)
668+
>>> np.linalg.norm(a, -1)
669+
array(0.)
670+
>>> np.linalg.norm(b, -1)
671+
array(6.)
672+
>>> np.linalg.norm(a, 2)
673+
array(7.745966692414834)
674+
>>> np.linalg.norm(b, 2)
675+
array(7.3484692283495345)
676+
677+
>>> np.linalg.norm(a, -2)
678+
array(0.)
679+
>>> np.linalg.norm(b, -2)
680+
array(1.8570331885190563e-016) # may vary
681+
>>> np.linalg.norm(a, 3)
682+
array(5.8480354764257312) # may vary
683+
>>> np.linalg.norm(a, -3)
684+
array(0.)
685+
686+
Using the `axis` argument to compute vector norms:
687+
688+
>>> c = np.array([[ 1, 2, 3],
689+
... [-1, 1, 4]])
690+
>>> np.linalg.norm(c, axis=0)
691+
array([ 1.41421356, 2.23606798, 5. ])
692+
>>> np.linalg.norm(c, axis=1)
693+
array([ 3.74165739, 4.24264069])
694+
>>> np.linalg.norm(c, ord=1, axis=1)
695+
array([ 6., 6.])
696+
697+
Using the `axis` argument to compute matrix norms:
698+
699+
>>> m = np.arange(8).reshape(2,2,2)
700+
>>> np.linalg.norm(m, axis=(1,2))
701+
array([ 3.74165739, 11.22497216])
702+
>>> np.linalg.norm(m[0, :, :]), np.linalg.norm(m[1, :, :])
703+
(array(3.7416573867739413), array(11.224972160321824))
655704
656-
return result
705+
"""
657706

658-
return call_origin(numpy.linalg.norm, x1, ord, axis, keepdims)
707+
dpnp.check_supported_arrays_type(x)
708+
return dpnp_norm(x, ord, axis, keepdims)
659709

660710

661711
def qr(a, mode="reduced"):

0 commit comments

Comments
 (0)