Skip to content

Commit b72f953

Browse files
authored
Add mean keyword to dpnp.std and dpnp.var (#2271)
The PR proposes to add `mean` keyword argument to `dpnp.std`, `dpnp.var`, `dpnp.nanstd`, `dpnp.nanvar` functions and `dpnp.ndarray.std`, `dpnp.ndarray.var` methods. The keyword was introduced by NumPy 2.0 and intended to improve the performance: > Often when the standard deviation is needed the mean is also needed; the same holds for the variance and the mean. With the current code the mean is then calculated twice, this can be prevented if the functions calculating the variance or the standard deviation can use a precalculated mean This PR implements similar improvement in dpnp code.
1 parent d509d3f commit b72f953

File tree

7 files changed

+718
-584
lines changed

7 files changed

+718
-584
lines changed

.github/workflows/build-sphinx.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ jobs:
223223
PR_NUM: ${{ github.event.number }}
224224
uses: mshick/add-pr-comment@b8f338c590a895d50bcbfa6c5859251edc8952fc # v2.8.2
225225
with:
226+
message-id: url_to_docs
226227
message: |
227228
View rendered docs @ https://intelpython.github.io/dpnp/pull/${{ env.PR_NUM }}/index.html
228229
allow-repeats: false

.github/workflows/conda-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,9 +600,9 @@ jobs:
600600
if: ${{ github.event.pull_request && !github.event.pull_request.head.repo.fork }}
601601
uses: mshick/add-pr-comment@b8f338c590a895d50bcbfa6c5859251edc8952fc # v2.8.2
602602
with:
603+
message-id: array_api_results
603604
message: |
604605
${{ env.MESSAGE }}
605-
refresh-message-position: true
606606
607607
cleanup_packages:
608608
name: Clean up anaconda packages

dpnp/dpnp_array.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1731,6 +1731,7 @@ def std(
17311731
keepdims=False,
17321732
*,
17331733
where=True,
1734+
mean=None,
17341735
):
17351736
"""
17361737
Returns the standard deviation of the array elements, along given axis.
@@ -1739,7 +1740,9 @@ def std(
17391740
17401741
"""
17411742

1742-
return dpnp.std(self, axis, dtype, out, ddof, keepdims, where=where)
1743+
return dpnp.std(
1744+
self, axis, dtype, out, ddof, keepdims, where=where, mean=mean
1745+
)
17431746

17441747
@property
17451748
def strides(self):
@@ -1938,6 +1941,7 @@ def var(
19381941
keepdims=False,
19391942
*,
19401943
where=True,
1944+
mean=None,
19411945
):
19421946
"""
19431947
Returns the variance of the array elements, along given axis.
@@ -1946,7 +1950,9 @@ def var(
19461950
19471951
"""
19481952

1949-
return dpnp.var(self, axis, dtype, out, ddof, keepdims, where=where)
1953+
return dpnp.var(
1954+
self, axis, dtype, out, ddof, keepdims, where=where, mean=mean
1955+
)
19501956

19511957

19521958
# 'view'

dpnp/dpnp_iface_nanfunctions.py

Lines changed: 141 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
3838
"""
3939

40+
# pylint: disable=duplicate-code
41+
4042
import warnings
4143

4244
import dpnp
@@ -955,7 +957,15 @@ def nansum(
955957

956958

957959
def nanstd(
958-
a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=True
960+
a,
961+
axis=None,
962+
dtype=None,
963+
out=None,
964+
ddof=0,
965+
keepdims=False,
966+
*,
967+
where=True,
968+
mean=None,
959969
):
960970
"""
961971
Compute the standard deviation along the specified axis,
@@ -969,40 +979,52 @@ def nanstd(
969979
Input array.
970980
axis : {None, int, tuple of ints}, optional
971981
Axis or axes along which the standard deviations must be computed.
972-
If a tuple of unique integers is given, the standard deviations
973-
are computed over multiple axes. If ``None``, the standard deviation
974-
is computed over the entire array.
982+
If a tuple of unique integers is given, the standard deviations are
983+
computed over multiple axes. If ``None``, the standard deviation is
984+
computed over the entire array.
985+
975986
Default: ``None``.
976987
dtype : {None, dtype}, optional
977-
Type to use in computing the standard deviation. By default,
978-
if `a` has a floating-point data type, the returned array
979-
will have the same data type as `a`.
980-
If `a` has a boolean or integral data type, the returned array
981-
will have the default floating point data type for the device
988+
Type to use in computing the standard deviation. By default, if `a` has
989+
a floating-point data type, the returned array will have the same data
990+
type as `a`. If `a` has a boolean or integral data type, the returned
991+
array will have the default floating point data type for the device
982992
where input array `a` is allocated.
993+
994+
Default: ``None``.
983995
out : {None, dpnp.ndarray, usm_ndarray}, optional
984996
Alternative output array in which to place the result. It must have
985997
the same shape as the expected output but the type (of the calculated
986998
values) will be cast if necessary.
999+
1000+
Default: ``None``.
9871001
ddof : {int, float}, optional
988-
Means Delta Degrees of Freedom. The divisor used in calculations
989-
is ``N - ddof``, where ``N`` the number of non-NaN elements.
990-
Default: `0.0`.
1002+
Means Delta Degrees of Freedom. The divisor used in calculations is
1003+
``N - ddof``, where ``N`` the number of non-NaN elements.
1004+
1005+
Default: ``0.0``.
9911006
keepdims : {None, bool}, optional
9921007
If ``True``, the reduced axes (dimensions) are included in the result
993-
as singleton dimensions, so that the returned array remains
994-
compatible with the input array according to Array Broadcasting
995-
rules. Otherwise, if ``False``, the reduced axes are not included in
996-
the returned array. Default: ``False``.
1008+
as singleton dimensions, so that the returned array remains compatible
1009+
with the input array according to Array Broadcasting rules. Otherwise,
1010+
if ``False``, the reduced axes are not included in the returned array.
1011+
1012+
Default: ``False``.
1013+
mean : {dpnp.ndarray, usm_ndarray}, optional
1014+
Provide the mean to prevent its recalculation. The mean should have
1015+
a shape as if it was calculated with ``keepdims=True``.
1016+
The axis for the calculation of the mean should be the same as used in
1017+
the call to this `nanstd` function.
1018+
1019+
Default: ``None``.
9971020
9981021
Returns
9991022
-------
10001023
out : dpnp.ndarray
1001-
An array containing the standard deviations. If the standard
1002-
deviation was computed over the entire array, a zero-dimensional
1003-
array is returned. If `ddof` is >= the number of non-NaN elements
1004-
in a slice or the slice contains only NaNs, then the result for
1005-
that slice is NaN.
1024+
An array containing the standard deviations. If the standard deviation
1025+
was computed over the entire array, a zero-dimensional array is
1026+
returned. If `ddof` is >= the number of non-NaN elements in a slice or
1027+
the slice contains only NaNs, then the result for that slice is NaN.
10061028
10071029
Limitations
10081030
-----------
@@ -1011,6 +1033,19 @@ def nanstd(
10111033
10121034
Notes
10131035
-----
1036+
The standard deviation is the square root of the average of the squared
1037+
deviations from the mean: ``std = sqrt(mean(abs(x - x.mean())**2))``.
1038+
1039+
The average squared deviation is normally calculated as ``x.sum() / N``,
1040+
where ``N = len(x)``. If, however, `ddof` is specified, the divisor
1041+
``N - ddof`` is used instead. In standard statistical practice, ``ddof=1``
1042+
provides an unbiased estimator of the variance of the infinite population.
1043+
``ddof=0`` provides a maximum likelihood estimate of the variance for
1044+
normally distributed variables.
1045+
The standard deviation computed in this function is the square root of
1046+
the estimated variance, so even with ``ddof=1``, it will not be an unbiased
1047+
estimate of the standard deviation per se.
1048+
10141049
Note that, for complex numbers, the absolute value is taken before
10151050
squaring, so that the result is always real and non-negative.
10161051
@@ -1029,11 +1064,18 @@ def nanstd(
10291064
>>> import dpnp as np
10301065
>>> a = np.array([[1, np.nan], [3, 4]])
10311066
>>> np.nanstd(a)
1032-
array(1.247219128924647)
1067+
array(1.24721913)
10331068
>>> np.nanstd(a, axis=0)
1034-
array([1., 0.])
1069+
array([1., 0.])
10351070
>>> np.nanstd(a, axis=1)
1036-
array([0., 0.5]) # may vary
1071+
array([0. , 0.5]) # may vary
1072+
1073+
Using the mean keyword to save computation time:
1074+
1075+
>>> a = np.array([[14, 8, np.nan, 10], [7, 9, 10, 11], [np.nan, 15, 5, 10]])
1076+
>>> mean = np.nanmean(a, axis=1, keepdims=True)
1077+
>>> np.nanstd(a, axis=1, mean=mean)
1078+
array([2.49443826, 1.47901995, 4.0824829 ])
10371079
10381080
"""
10391081

@@ -1051,13 +1093,21 @@ def nanstd(
10511093
ddof=ddof,
10521094
keepdims=keepdims,
10531095
where=where,
1096+
mean=mean,
10541097
)
1055-
dpnp.sqrt(res, out=res)
1056-
return res
1098+
return dpnp.sqrt(res, out=res)
10571099

10581100

10591101
def nanvar(
1060-
a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=True
1102+
a,
1103+
axis=None,
1104+
dtype=None,
1105+
out=None,
1106+
ddof=0,
1107+
keepdims=False,
1108+
*,
1109+
where=True,
1110+
mean=None,
10611111
):
10621112
"""
10631113
Compute the variance along the specified axis, while ignoring NaNs.
@@ -1069,39 +1119,52 @@ def nanvar(
10691119
a : {dpnp.ndarray, usm_ndarray}
10701120
Input array.
10711121
axis : {None, int, tuple of ints}, optional
1072-
axis or axes along which the variances must be computed. If a tuple
1122+
Axis or axes along which the variances must be computed. If a tuple
10731123
of unique integers is given, the variances are computed over multiple
10741124
axes. If ``None``, the variance is computed over the entire array.
1125+
10751126
Default: ``None``.
10761127
dtype : {None, dtype}, optional
10771128
Type to use in computing the variance. By default, if `a` has a
10781129
floating-point data type, the returned array will have
1079-
the same data type as `a`.
1080-
If `a` has a boolean or integral data type, the returned array
1081-
will have the default floating point data type for the device
1082-
where input array `a` is allocated.
1130+
the same data type as `a`. If `a` has a boolean or integral data type,
1131+
the returned array will have the default floating point data type for
1132+
the device where input array `a` is allocated.
1133+
1134+
Default: ``None``.
10831135
out : {None, dpnp.ndarray, usm_ndarray}, optional
10841136
Alternative output array in which to place the result. It must have
10851137
the same shape as the expected output but the type (of the calculated
10861138
values) will be cast if necessary.
1139+
1140+
Default: ``None``.
10871141
ddof : {int, float}, optional
1088-
Means Delta Degrees of Freedom. The divisor used in calculations
1089-
is ``N - ddof``, where ``N`` represents the number of non-NaN elements.
1090-
Default: `0.0`.
1142+
Means Delta Degrees of Freedom. The divisor used in calculations is
1143+
``N - ddof``, where ``N`` represents the number of non-NaN elements.
1144+
1145+
Default: ``0.0``.
10911146
keepdims : {None, bool}, optional
10921147
If ``True``, the reduced axes (dimensions) are included in the result
1093-
as singleton dimensions, so that the returned array remains
1094-
compatible with the input array according to Array Broadcasting
1095-
rules. Otherwise, if ``False``, the reduced axes are not included in
1096-
the returned array. Default: ``False``.
1148+
as singleton dimensions, so that the returned array remains compatible
1149+
with the input array according to Array Broadcasting rules. Otherwise,
1150+
if ``False``, the reduced axes are not included in the returned array.
1151+
1152+
Default: ``False``.
1153+
mean : {dpnp.ndarray, usm_ndarray}, optional
1154+
Provide the mean to prevent its recalculation. The mean should have
1155+
a shape as if it was calculated with ``keepdims=True``.
1156+
The axis for the calculation of the mean should be the same as used in
1157+
the call to this `nanvar` function.
1158+
1159+
Default: ``None``.
10971160
10981161
Returns
10991162
-------
11001163
out : dpnp.ndarray
1101-
An array containing the variances. If the variance was computed
1102-
over the entire array, a zero-dimensional array is returned.
1103-
If `ddof` is >= the number of non-NaN elements in a slice or the
1104-
slice contains only NaNs, then the result for that slice is NaN.
1164+
An array containing the variances. If the variance was computed over
1165+
the entire array, a zero-dimensional array is returned. If `ddof` is >=
1166+
the number of non-NaN elements in a slice or the slice contains only
1167+
NaNs, then the result for that slice is NaN.
11051168
11061169
Limitations
11071170
-----------
@@ -1110,6 +1173,16 @@ def nanvar(
11101173
11111174
Notes
11121175
-----
1176+
The variance is the average of the squared deviations from the mean,
1177+
that is ``var = mean(abs(x - x.mean())**2)``.
1178+
1179+
The mean is normally calculated as ``x.sum() / N``, where ``N = len(x)``.
1180+
If, however, `ddof` is specified, the divisor ``N - ddof`` is used instead.
1181+
In standard statistical practice, ``ddof=1`` provides an unbiased estimator
1182+
of the variance of a hypothetical infinite population. ``ddof=0`` provides
1183+
a maximum likelihood estimate of the variance for normally distributed
1184+
variables.
1185+
11131186
Note that, for complex numbers, the absolute value is taken before squaring,
11141187
so that the result is always real and non-negative.
11151188
@@ -1127,11 +1200,18 @@ def nanvar(
11271200
>>> import dpnp as np
11281201
>>> a = np.array([[1, np.nan], [3, 4]])
11291202
>>> np.nanvar(a)
1130-
array(1.5555555555555554)
1203+
array(1.55555556)
11311204
>>> np.nanvar(a, axis=0)
1132-
array([1., 0.])
1205+
array([1., 0.])
11331206
>>> np.nanvar(a, axis=1)
1134-
array([0., 0.25]) # may vary
1207+
array([0. , 0.25]) # may vary
1208+
1209+
Using the mean keyword to save computation time:
1210+
1211+
>>> a = np.array([[14, 8, np.nan, 10], [7, 9, 10, 11], [np.nan, 15, 5, 10]])
1212+
>>> mean = np.nanmean(a, axis=1, keepdims=True)
1213+
>>> np.nanvar(a, axis=1, mean=mean)
1214+
array([ 6.22222222, 2.1875 , 16.66666667])
11351215
11361216
"""
11371217

@@ -1157,46 +1237,51 @@ def nanvar(
11571237
dtype = dpnp.dtype(dtype)
11581238
if not dpnp.issubdtype(dtype, dpnp.inexact):
11591239
raise TypeError("If input is inexact, then dtype must be inexact.")
1240+
11601241
if out is not None:
11611242
dpnp.check_supported_arrays_type(out)
11621243
if not dpnp.issubdtype(out.dtype, dpnp.inexact):
11631244
raise TypeError("If input is inexact, then out must be inexact.")
11641245

11651246
# Compute mean
1166-
var_dtype = a.real.dtype if dtype is None else dtype
11671247
cnt = dpnp.sum(
1168-
~mask, axis=axis, dtype=var_dtype, keepdims=True, where=where
1248+
~mask, axis=axis, dtype=dpnp.intp, keepdims=True, where=where
11691249
)
1170-
avg = dpnp.sum(arr, axis=axis, dtype=dtype, keepdims=True, where=where)
1171-
avg = dpnp.divide(avg, cnt, out=avg)
11721250

1173-
# Compute squared deviation from mean.
1251+
if mean is not None:
1252+
avg = mean
1253+
else:
1254+
avg = dpnp.sum(arr, axis=axis, dtype=dtype, keepdims=True, where=where)
1255+
avg = dpnp.divide(avg, cnt, out=avg)
1256+
1257+
# Compute squared deviation from mean
11741258
if arr.dtype == avg.dtype:
11751259
arr = dpnp.subtract(arr, avg, out=arr)
11761260
else:
11771261
arr = dpnp.subtract(arr, avg)
11781262
dpnp.copyto(arr, 0.0, where=mask)
1263+
11791264
if dpnp.issubdtype(arr.dtype, dpnp.complexfloating):
11801265
sqr = dpnp.multiply(arr, arr.conj(), out=arr).real
11811266
else:
1182-
sqr = dpnp.multiply(arr, arr, out=arr)
1267+
sqr = dpnp.square(arr, out=arr)
11831268

11841269
# Compute variance
11851270
var = dpnp.sum(
11861271
sqr,
11871272
axis=axis,
1188-
dtype=var_dtype,
1273+
dtype=dtype,
11891274
out=out,
11901275
keepdims=keepdims,
11911276
where=where,
11921277
)
11931278

11941279
if var.ndim < cnt.ndim:
11951280
cnt = cnt.squeeze(axis)
1196-
cnt -= ddof
1197-
dpnp.divide(var, cnt, out=var)
1281+
dof = cnt - ddof
1282+
dpnp.divide(var, dof, out=var)
11981283

1199-
isbad = cnt <= 0
1284+
isbad = dof <= 0
12001285
if dpnp.any(isbad):
12011286
# NaN, inf, or negative numbers are all possible bad
12021287
# values, so explicitly replace them with NaN.

0 commit comments

Comments
 (0)