Skip to content

Commit d1ed10f

Browse files
committed
use replace_nan func for nanprod
1 parent fa45337 commit d1ed10f

File tree

3 files changed

+9
-19
lines changed

3 files changed

+9
-19
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,6 @@ enum class DPNPFuncName : size_t
216216
DPNP_FN_MULTIPLY_EXT, /**< Used in numpy.multiply() impl, requires extra
217217
parameters */
218218
DPNP_FN_NANVAR, /**< Used in numpy.nanvar() impl */
219-
DPNP_FN_NANVAR_EXT, /**< Used in numpy.nanvar() impl, requires extra
220-
parameters */
221219
DPNP_FN_NEGATIVE, /**< Used in numpy.negative() impl */
222220
DPNP_FN_NONZERO, /**< Used in numpy.nonzero() impl */
223221
DPNP_FN_ONES, /**< Used in numpy.ones() impl */

dpnp/dpnp_iface.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -465,11 +465,12 @@ def get_result_array(a, out=None, casting="safe"):
465465
----------
466466
a : {dpnp_array}
467467
Input array.
468-
469468
out : {dpnp_array, usm_ndarray}
470469
If provided, value of `a` array will be copied into it
471470
according to ``safe`` casting rule.
472471
It should be of the appropriate shape.
472+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
473+
Controls what kind of data casting may occur.
473474
474475
Returns
475476
-------
@@ -625,16 +626,12 @@ def _replace_nan(a, val):
625626
NaNs, otherwise return None.
626627
"""
627628

628-
if dpnp.is_supported_array_or_scalar(a):
629-
if issubclass(a.dtype.type, dpnp.inexact):
630-
mask = dpnp.isnan(a)
631-
a = dpnp.array(a, copy=True)
632-
dpnp.copyto(a, val, where=mask)
633-
else:
634-
mask = None
629+
dpnp.check_supported_arrays_type(a)
630+
if issubclass(a.dtype.type, dpnp.inexact):
631+
mask = dpnp.isnan(a)
632+
a = dpnp.array(a, copy=True)
633+
dpnp.copyto(a, val, where=mask)
635634
else:
636-
raise TypeError(
637-
"An array must be any of supported type, but got {}".format(type(a))
638-
)
635+
mask = None
639636

640637
return a, mask

dpnp/dpnp_iface_mathematical.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1783,12 +1783,7 @@ def nanprod(
17831783
17841784
"""
17851785

1786-
dpnp.check_supported_arrays_type(a)
1787-
1788-
if issubclass(a.dtype.type, dpnp.inexact):
1789-
mask = dpnp.isnan(a)
1790-
a = dpnp.array(a, copy=True)
1791-
dpnp.copyto(a, 1, where=mask)
1786+
a, mask = dpnp._replace_nan(a, 1)
17921787

17931788
return dpnp.prod(
17941789
a,

0 commit comments

Comments
 (0)