Skip to content

REF: Use numpy set methods in interpolate #57997

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 25 additions & 42 deletions pandas/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,20 +471,20 @@ def _interpolate_1d(
if valid.all():
return

# These are sets of index pointers to invalid values... i.e. {0, 1, etc...
all_nans = set(np.flatnonzero(invalid))
# These index pointers to invalid values... i.e. {0, 1, etc...
all_nans = np.flatnonzero(invalid)

first_valid_index = find_valid_index(how="first", is_valid=valid)
if first_valid_index is None: # no nan found in start
first_valid_index = 0
start_nans = set(range(first_valid_index))
start_nans = np.arange(first_valid_index)

last_valid_index = find_valid_index(how="last", is_valid=valid)
if last_valid_index is None: # no nan found in end
last_valid_index = len(yvalues)
end_nans = set(range(1 + last_valid_index, len(valid)))
end_nans = np.arange(1 + last_valid_index, len(valid))

# Like the sets above, preserve_nans contains indices of invalid values,
# preserve_nans contains indices of invalid values,
# but in this case, it is the final set of indices that need to be
# preserved as NaN after the interpolation.

Expand All @@ -493,27 +493,25 @@ def _interpolate_1d(
# are more than 'limit' away from the prior non-NaN.

# set preserve_nans based on direction using _interp_limit
preserve_nans: list | set
if limit_direction == "forward":
preserve_nans = start_nans | set(_interp_limit(invalid, limit, 0))
preserve_nans = np.union1d(start_nans, _interp_limit(invalid, limit, 0))
elif limit_direction == "backward":
preserve_nans = end_nans | set(_interp_limit(invalid, 0, limit))
preserve_nans = np.union1d(end_nans, _interp_limit(invalid, 0, limit))
else:
# both directions... just use _interp_limit
preserve_nans = set(_interp_limit(invalid, limit, limit))
preserve_nans = np.unique(_interp_limit(invalid, limit, limit))

# if limit_area is set, add either mid or outside indices
# to preserve_nans GH #16284
if limit_area == "inside":
# preserve NaNs on the outside
preserve_nans |= start_nans | end_nans
preserve_nans = np.union1d(preserve_nans, start_nans)
preserve_nans = np.union1d(preserve_nans, end_nans)
elif limit_area == "outside":
# preserve NaNs on the inside
mid_nans = all_nans - start_nans - end_nans
preserve_nans |= mid_nans

# sort preserve_nans and convert to list
preserve_nans = sorted(preserve_nans)
mid_nans = np.setdiff1d(all_nans, start_nans, assume_unique=True)
mid_nans = np.setdiff1d(mid_nans, end_nans, assume_unique=True)
preserve_nans = np.union1d(preserve_nans, mid_nans)

is_datetimelike = yvalues.dtype.kind in "mM"

Expand Down Expand Up @@ -1027,7 +1025,7 @@ def clean_reindex_fill_method(method) -> ReindexMethod | None:

def _interp_limit(
invalid: npt.NDArray[np.bool_], fw_limit: int | None, bw_limit: int | None
):
) -> np.ndarray:
"""
Get indexers of values that won't be filled
because they exceed the limits.
Expand Down Expand Up @@ -1059,20 +1057,23 @@ def _interp_limit(invalid, fw_limit, bw_limit):
# 1. operate on the reversed array
# 2. subtract the returned indices from N - 1
N = len(invalid)
f_idx = set()
b_idx = set()
f_idx = np.array([], dtype=np.int64)
b_idx = np.array([], dtype=np.int64)
assume_unique = True

def inner(invalid, limit: int):
limit = min(limit, N)
windowed = _rolling_window(invalid, limit + 1).all(1)
idx = set(np.where(windowed)[0] + limit) | set(
np.where((~invalid[: limit + 1]).cumsum() == 0)[0]
windowed = np.lib.stride_tricks.sliding_window_view(invalid, limit + 1).all(1)
idx = np.union1d(
np.where(windowed)[0] + limit,
np.where((~invalid[: limit + 1]).cumsum() == 0)[0],
)
return idx

if fw_limit is not None:
if fw_limit == 0:
f_idx = set(np.where(invalid)[0])
f_idx = np.where(invalid)[0]
assume_unique = False
else:
f_idx = inner(invalid, fw_limit)

Expand All @@ -1082,26 +1083,8 @@ def inner(invalid, limit: int):
# just use forwards
return f_idx
else:
b_idx_inv = list(inner(invalid[::-1], bw_limit))
b_idx = set(N - 1 - np.asarray(b_idx_inv))
b_idx = N - 1 - inner(invalid[::-1], bw_limit)
if fw_limit == 0:
return b_idx

return f_idx & b_idx


def _rolling_window(a: npt.NDArray[np.bool_], window: int) -> npt.NDArray[np.bool_]:
"""
[True, True, False, True, False], 2 ->

[
[True, True],
[True, False],
[False, True],
[True, False],
]
"""
# https://stackoverflow.com/a/6811241
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
strides = a.strides + (a.strides[-1],)
return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
return np.intersect1d(f_idx, b_idx, assume_unique=assume_unique)