Skip to content

CLN: EWMA cython code and function dispatch #34636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Status: Merged (1 commit merged on Jun 8, 2020)
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File: pandas/_libs/window/aggregations.pyx (28 changes: 14 additions & 14 deletions)
Columns: original file line number | diff line number | diff line change
Expand Up @@ -1793,19 +1793,19 @@ def ewma(float64_t[:] vals, float64_t com, int adjust, bint ignore_na, int minp)
new_wt = 1. if adjust else alpha

weighted_avg = vals[0]
is_observation = (weighted_avg == weighted_avg)
is_observation = weighted_avg == weighted_avg
nobs = int(is_observation)
output[0] = weighted_avg if (nobs >= minp) else NaN
output[0] = weighted_avg if nobs >= minp else NaN
old_wt = 1.

with nogil:
for i in range(1, N):
cur = vals[i]
is_observation = (cur == cur)
is_observation = cur == cur
nobs += is_observation
if weighted_avg == weighted_avg:

if is_observation or (not ignore_na):
if is_observation or not ignore_na:

old_wt *= old_wt_factor
if is_observation:
Expand All @@ -1821,7 +1821,7 @@ def ewma(float64_t[:] vals, float64_t com, int adjust, bint ignore_na, int minp)
elif is_observation:
weighted_avg = cur

output[i] = weighted_avg if (nobs >= minp) else NaN
output[i] = weighted_avg if nobs >= minp else NaN

return output

Expand Down Expand Up @@ -1851,16 +1851,16 @@ def ewmcov(float64_t[:] input_x, float64_t[:] input_y,
"""

cdef:
Py_ssize_t N = len(input_x)
Py_ssize_t N = len(input_x), M = len(input_y)
float64_t alpha, old_wt_factor, new_wt, mean_x, mean_y, cov
float64_t sum_wt, sum_wt2, old_wt, cur_x, cur_y, old_mean_x, old_mean_y
float64_t numerator, denominator
Py_ssize_t i, nobs
ndarray[float64_t] output
bint is_observation

if <Py_ssize_t>len(input_y) != N:
raise ValueError(f"arrays are of different lengths ({N} and {len(input_y)})")
if M != N:
raise ValueError(f"arrays are of different lengths ({N} and {M})")

output = np.empty(N, dtype=float)
if N == 0:
Expand All @@ -1874,12 +1874,12 @@ def ewmcov(float64_t[:] input_x, float64_t[:] input_y,

mean_x = input_x[0]
mean_y = input_y[0]
is_observation = ((mean_x == mean_x) and (mean_y == mean_y))
is_observation = (mean_x == mean_x) and (mean_y == mean_y)
nobs = int(is_observation)
if not is_observation:
mean_x = NaN
mean_y = NaN
output[0] = (0. if bias else NaN) if (nobs >= minp) else NaN
output[0] = (0. if bias else NaN) if nobs >= minp else NaN
cov = 0.
sum_wt = 1.
sum_wt2 = 1.
Expand All @@ -1890,10 +1890,10 @@ def ewmcov(float64_t[:] input_x, float64_t[:] input_y,
for i in range(1, N):
cur_x = input_x[i]
cur_y = input_y[i]
is_observation = ((cur_x == cur_x) and (cur_y == cur_y))
is_observation = (cur_x == cur_x) and (cur_y == cur_y)
nobs += is_observation
if mean_x == mean_x:
if is_observation or (not ignore_na):
if is_observation or not ignore_na:
sum_wt *= old_wt_factor
sum_wt2 *= (old_wt_factor * old_wt_factor)
old_wt *= old_wt_factor
Expand Down Expand Up @@ -1929,8 +1929,8 @@ def ewmcov(float64_t[:] input_x, float64_t[:] input_y,
if not bias:
numerator = sum_wt * sum_wt
denominator = numerator - sum_wt2
if (denominator > 0.):
output[i] = ((numerator / denominator) * cov)
if denominator > 0:
output[i] = (numerator / denominator) * cov
else:
output[i] = NaN
else:
Expand Down
File: pandas/core/window/ewm.py (32 changes: 12 additions & 20 deletions)
Columns: original file line number | diff line number | diff line change
@@ -1,3 +1,4 @@
from functools import partial
from textwrap import dedent

import numpy as np
Expand Down Expand Up @@ -219,7 +220,7 @@ def aggregate(self, func, *args, **kwargs):

agg = aggregate

def _apply(self, func, **kwargs):
def _apply(self, func):
"""
Rolling statistical measure using supplied function. Designed to be
used with passed-in Cython array-based functions.
Expand Down Expand Up @@ -253,23 +254,6 @@ def _apply(self, func, **kwargs):
results.append(values.copy())
continue

# if we have a string function name, wrap it
if isinstance(func, str):
cfunc = getattr(window_aggregations, func, None)
if cfunc is None:
raise ValueError(
f"we do not support this function in window_aggregations.{func}"
)

def func(arg):
return cfunc(
arg,
self.com,
int(self.adjust),
int(self.ignore_na),
int(self.min_periods),
)

results.append(np.apply_along_axis(func, self.axis, values))

return self._wrap_results(results, block_list, obj, exclude)
Expand All @@ -286,7 +270,15 @@ def mean(self, *args, **kwargs):
Arguments and keyword arguments to be passed into func.
"""
nv.validate_window_func("mean", args, kwargs)
return self._apply("ewma", **kwargs)
window_func = self._get_roll_func("ewma")
window_func = partial(
window_func,
com=self.com,
adjust=int(self.adjust),
ignore_na=self.ignore_na,
minp=int(self.min_periods),
)
return self._apply(window_func)

@Substitution(name="ewm", func_name="std")
@Appender(_doc_template)
Expand Down Expand Up @@ -320,7 +312,7 @@ def f(arg):
int(bias),
)

return self._apply(f, **kwargs)
return self._apply(f)

@Substitution(name="ewm", func_name="cov")
@Appender(_doc_template)
Expand Down