Skip to content

Commit f3fce88

Browse files
committed
ENH: Add CoW optimization for fillna
1 parent 3e8c3b0 commit f3fce88

File tree

4 files changed: 129 additions and 18 deletions.

doc/source/whatsnew/v2.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -223,6 +223,7 @@ Copy-on-Write improvements
223223
- :meth:`DataFrame.to_period` / :meth:`Series.to_period`
224224
- :meth:`DataFrame.truncate`
225225
- :meth:`DataFrame.tz_convert` / :meth:`Series.tz_localize`
226+
- :meth:`DataFrame.fillna` / :meth:`Series.fillna`
226227
- :meth:`DataFrame.infer_objects` / :meth:`Series.infer_objects`
227228
- :meth:`DataFrame.astype` / :meth:`Series.astype`
228229
- :func:`concat`

pandas/core/internals/blocks.py

Lines changed: 43 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -421,7 +421,9 @@ def coerce_to_target_dtype(self, other) -> Block:
421421
return self.astype(new_dtype, copy=False)
422422

423423
@final
424-
def _maybe_downcast(self, blocks: list[Block], downcast=None) -> list[Block]:
424+
def _maybe_downcast(
425+
self, blocks: list[Block], downcast=None, using_cow: bool = False
426+
) -> list[Block]:
425427
if downcast is False:
426428
return blocks
427429

@@ -431,23 +433,30 @@ def _maybe_downcast(self, blocks: list[Block], downcast=None) -> list[Block]:
431433
# but ATM it breaks too much existing code.
432434
# split and convert the blocks
433435

434-
return extend_blocks([blk.convert() for blk in blocks])
436+
return extend_blocks(
437+
[blk.convert(copy=not using_cow, using_cow=using_cow) for blk in blocks]
438+
)
435439

436440
if downcast is None:
437441
return blocks
438442

439-
return extend_blocks([b._downcast_2d(downcast) for b in blocks])
443+
return extend_blocks(
444+
[b._downcast_2d(downcast, using_cow=using_cow) for b in blocks]
445+
)
440446

441447
@final
442448
@maybe_split
443-
def _downcast_2d(self, dtype) -> list[Block]:
449+
def _downcast_2d(self, dtype, using_cow: bool = False) -> list[Block]:
444450
"""
445451
downcast specialized to 2D case post-validation.
446452
447453
Refactored to allow use of maybe_split.
448454
"""
449455
new_values = maybe_downcast_to_dtype(self.values, dtype=dtype)
450-
return [self.make_block(new_values)]
456+
refs = None
457+
if using_cow and new_values is self.values:
458+
refs = self.refs
459+
return [self.make_block(new_values, refs=refs)]
451460

452461
def convert(
453462
self,
@@ -1152,7 +1161,12 @@ def where(self, other, cond, _downcast: str | bool = "infer") -> list[Block]:
11521161
return [self.make_block(result)]
11531162

11541163
def fillna(
1155-
self, value, limit: int | None = None, inplace: bool = False, downcast=None
1164+
self,
1165+
value,
1166+
limit: int | None = None,
1167+
inplace: bool = False,
1168+
downcast=None,
1169+
using_cow: bool = False,
11561170
) -> list[Block]:
11571171
"""
11581172
fillna on the block with the value. If we fail, then convert to
@@ -1171,19 +1185,25 @@ def fillna(
11711185
if noop:
11721186
# we can't process the value, but nothing to do
11731187
if inplace:
1188+
if using_cow:
1189+
return [self.copy(deep=False)]
11741190
# Arbitrarily imposing the convention that we ignore downcast
11751191
# on no-op when inplace=True
11761192
return [self]
11771193
else:
11781194
# GH#45423 consistent downcasting on no-ops.
1179-
nb = self.copy()
1180-
nbs = nb._maybe_downcast([nb], downcast=downcast)
1195+
nb = self.copy(deep=not using_cow)
1196+
nbs = nb._maybe_downcast([nb], downcast=downcast, using_cow=using_cow)
11811197
return nbs
11821198

11831199
if limit is not None:
11841200
mask[mask.cumsum(self.ndim - 1) > limit] = False
11851201

11861202
if inplace:
1203+
if using_cow and self.refs.has_reference():
1204+
# TODO(CoW): If using_cow is implemented for putmask we can defer
1205+
# the copy
1206+
self = self.copy()
11871207
nbs = self.putmask(mask.T, value)
11881208
else:
11891209
# without _downcast, we would break
@@ -1194,7 +1214,10 @@ def fillna(
11941214
# makes a difference bc blk may have object dtype, which has
11951215
# different behavior in _maybe_downcast.
11961216
return extend_blocks(
1197-
[blk._maybe_downcast([blk], downcast=downcast) for blk in nbs]
1217+
[
1218+
blk._maybe_downcast([blk], downcast=downcast, using_cow=using_cow)
1219+
for blk in nbs
1220+
]
11981221
)
11991222

12001223
def interpolate(
@@ -1662,12 +1685,21 @@ class ExtensionBlock(libinternals.Block, EABackedBlock):
16621685
values: ExtensionArray
16631686

16641687
def fillna(
1665-
self, value, limit: int | None = None, inplace: bool = False, downcast=None
1688+
self,
1689+
value,
1690+
limit: int | None = None,
1691+
inplace: bool = False,
1692+
downcast=None,
1693+
using_cow: bool = False,
16661694
) -> list[Block]:
16671695
if is_interval_dtype(self.dtype):
16681696
# Block.fillna handles coercion (test_fillna_interval)
16691697
return super().fillna(
1670-
value=value, limit=limit, inplace=inplace, downcast=downcast
1698+
value=value,
1699+
limit=limit,
1700+
inplace=inplace,
1701+
downcast=downcast,
1702+
using_cow=using_cow,
16711703
)
16721704
new_values = self.values.fillna(value=value, method=None, limit=limit)
16731705
nb = self.make_block_same_class(new_values)

pandas/core/internals/managers.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -410,15 +410,14 @@ def fillna(self: T, value, limit, inplace: bool, downcast) -> T:
410410
if limit is not None:
411411
# Do this validation even if we go through one of the no-op paths
412412
limit = libalgos.validate_limit(None, limit=limit)
413-
if inplace:
414-
# TODO(CoW) can be optimized to only copy those blocks that have refs
415-
if using_copy_on_write() and any(
416-
not self._has_no_reference_block(i) for i in range(len(self.blocks))
417-
):
418-
self = self.copy()
419413

420414
return self.apply(
421-
"fillna", value=value, limit=limit, inplace=inplace, downcast=downcast
415+
"fillna",
416+
value=value,
417+
limit=limit,
418+
inplace=inplace,
419+
downcast=downcast,
420+
using_cow=using_copy_on_write(),
422421
)
423422

424423
def astype(self: T, dtype, copy: bool | None = False, errors: str = "raise") -> T:
Lines changed: 79 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,79 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pandas import (
5+
DataFrame,
6+
Interval,
7+
Series,
8+
interval_range,
9+
)
10+
import pandas._testing as tm
11+
from pandas.tests.copy_view.util import get_array
12+
13+
14+
def test_fillna(using_copy_on_write):
15+
df = DataFrame({"a": [1.5, np.nan], "b": 1})
16+
df_orig = df.copy()
17+
18+
df2 = df.fillna(5.5)
19+
if using_copy_on_write:
20+
assert np.shares_memory(get_array(df, "b"), get_array(df2, "b"))
21+
else:
22+
assert not np.shares_memory(get_array(df, "b"), get_array(df2, "b"))
23+
24+
df2.iloc[0, 1] = 100
25+
tm.assert_frame_equal(df_orig, df)
26+
27+
28+
@pytest.mark.parametrize("downcast", [None, False])
29+
def test_fillna_inplace(using_copy_on_write, downcast):
30+
df = DataFrame({"a": [1.5, np.nan], "b": 1})
31+
arr_a = get_array(df, "a")
32+
arr_b = get_array(df, "b")
33+
34+
df.fillna(5.5, inplace=True, downcast=downcast)
35+
assert np.shares_memory(get_array(df, "a"), arr_a)
36+
assert np.shares_memory(get_array(df, "b"), arr_b)
37+
if using_copy_on_write:
38+
assert df._mgr._has_no_reference(0)
39+
assert df._mgr._has_no_reference(1)
40+
41+
42+
def test_fillna_inplace_reference(using_copy_on_write):
43+
df = DataFrame({"a": [1.5, np.nan], "b": 1})
44+
df_orig = df.copy()
45+
arr_a = get_array(df, "a")
46+
arr_b = get_array(df, "b")
47+
view = df[:]
48+
49+
df.fillna(5.5, inplace=True)
50+
if using_copy_on_write:
51+
assert not np.shares_memory(get_array(df, "a"), arr_a)
52+
assert np.shares_memory(get_array(df, "b"), arr_b)
53+
assert view._mgr._has_no_reference(0)
54+
assert df._mgr._has_no_reference(0)
55+
tm.assert_frame_equal(view, df_orig)
56+
else:
57+
assert np.shares_memory(get_array(df, "a"), arr_a)
58+
assert np.shares_memory(get_array(df, "b"), arr_b)
59+
expected = DataFrame({"a": [1.5, 5.5], "b": 1})
60+
tm.assert_frame_equal(df, expected)
61+
62+
63+
def test_fillna_interval_inplace_reference(using_copy_on_write):
64+
ser = Series(interval_range(start=0, end=5), name="a")
65+
ser.iloc[1] = np.nan
66+
67+
ser_orig = ser.copy()
68+
view = ser[:]
69+
ser.fillna(value=Interval(left=0, right=5), inplace=True)
70+
71+
if using_copy_on_write:
72+
assert not np.shares_memory(
73+
get_array(ser, "a").left.values, get_array(view, "a").left.values
74+
)
75+
tm.assert_series_equal(view, ser_orig)
76+
else:
77+
assert np.shares_memory(
78+
get_array(ser, "a").left.values, get_array(view, "a").left.values
79+
)

0 commit comments

Comments
 (0)