Skip to content

Commit 63052a6

Browse files
ENH: Add lazy copy for take and between_time (#50476)
Co-authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
1 parent 9d1f72b commit 63052a6

File tree

3 files changed: 59 additions and 1 deletion.

pandas/core/generic.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
)
3636

3737
from pandas._libs import lib
38+
from pandas._libs.lib import array_equal_fast
3839
from pandas._libs.tslibs import (
3940
Period,
4041
Tick,
@@ -3778,6 +3779,18 @@ def _take(
37783779
37793780
See the docstring of `take` for full explanation of the parameters.
37803781
"""
3782+
if not isinstance(indices, slice):
3783+
indices = np.asarray(indices, dtype=np.intp)
3784+
if (
3785+
axis == 0
3786+
and indices.ndim == 1
3787+
and using_copy_on_write()
3788+
and array_equal_fast(
3789+
indices,
3790+
np.arange(0, len(self), dtype=np.intp),
3791+
)
3792+
):
3793+
return self.copy(deep=None)
37813794

37823795
new_data = self._mgr.take(
37833796
indices,

pandas/core/series.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@
3232
properties,
3333
reshape,
3434
)
35-
from pandas._libs.lib import no_default
35+
from pandas._libs.lib import (
36+
array_equal_fast,
37+
no_default,
38+
)
3639
from pandas._typing import (
3740
AggFuncType,
3841
AlignJoin,
@@ -879,6 +882,14 @@ def take(self, indices, axis: Axis = 0, **kwargs) -> Series:
879882
nv.validate_take((), kwargs)
880883

881884
indices = ensure_platform_int(indices)
885+
886+
if (
887+
indices.ndim == 1
888+
and using_copy_on_write()
889+
and array_equal_fast(indices, np.arange(0, len(self), dtype=indices.dtype))
890+
):
891+
return self.copy(deep=None)
892+
882893
new_index = self.index.take(indices)
883894
new_values = self._values.take(indices)
884895

pandas/tests/copy_view/test_methods.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,40 @@ def test_assign_drop_duplicates(using_copy_on_write, method):
596596
tm.assert_frame_equal(df, df_orig)
597597

598598

599+
@pytest.mark.parametrize("obj", [Series([1, 2]), DataFrame({"a": [1, 2]})])
def test_take(using_copy_on_write, obj):
    """Check ``take`` with every row, in original order, against ``obj``.

    When ``using_copy_on_write`` is true the result must initially share
    memory with ``obj`` and stop sharing once the result is written to.
    Otherwise the result must never share memory with ``obj``.  In both
    modes ``obj`` itself must be left unchanged.
    """
    snapshot = obj.copy()
    taken = obj.take([0, 1])

    shares_buffer = np.shares_memory(taken.values, obj.values)
    if not using_copy_on_write:
        assert not shares_buffer
    else:
        assert shares_buffer

    # Writing to the result must never be visible through the parent.
    taken.iloc[0] = 0
    if using_copy_on_write:
        assert not np.shares_memory(taken.values, obj.values)
    tm.assert_equal(obj, snapshot)
614+
615+
616+
@pytest.mark.parametrize("obj", [Series([1, 2]), DataFrame({"a": [1, 2]})])
def test_between_time(using_copy_on_write, obj):
    """Check ``between_time`` when the time window selects every row.

    When ``using_copy_on_write`` is true the result must initially share
    memory with ``obj`` and stop sharing once the result is written to.
    Otherwise the result must never share memory with ``obj``.  In both
    modes ``obj`` itself must be left unchanged.
    """
    # pytest passes parametrized values to the test as-is (no copy), so
    # work on a private copy: assigning the index below must not leak into
    # the shared parameter object.
    obj = obj.copy()
    # Two timestamps, 2018-04-09 00:00 and 2018-04-10 00:20; both fall in
    # the 0:00-1:00 window requested below, so every row is selected.
    obj.index = date_range("2018-04-09", periods=2, freq="1D20min")
    obj_orig = obj.copy()
    obj2 = obj.between_time("0:00", "1:00")

    if using_copy_on_write:
        assert np.shares_memory(obj2.values, obj.values)
    else:
        assert not np.shares_memory(obj2.values, obj.values)

    # Writing to the result must never be visible through the parent.
    obj2.iloc[0] = 0
    if using_copy_on_write:
        assert not np.shares_memory(obj2.values, obj.values)
    tm.assert_equal(obj, obj_orig)
631+
632+
599633
def test_reindex_like(using_copy_on_write):
600634
df = DataFrame({"a": [1, 2], "b": "a"})
601635
other = DataFrame({"b": "a", "a": [1, 2]})

0 commit comments

Comments
 (0)