-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
BUG/API: make setitem-inplace preserve dtype when possible with PandasArray, IntegerArray, FloatingArray #39044
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
6fffb02
160f3f7
de10708
ba98a99
84261a7
284f36a
2639b5c
b2aa366
7aeb2b5
715a602
b55155b
c72e566
0b9f343
cd6adbe
071ab1b
6ab44af
daacff8
450bf73
dba7c11
9258cbb
88309ab
1847209
6b8cc31
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,8 +32,11 @@ | |
from pandas.tests.extension import base | ||
|
||
|
||
def make_data(): | ||
return list(range(1, 9)) + [pd.NA] + list(range(10, 98)) + [pd.NA] + [99, 100] | ||
def make_data(with_nas: bool = True): | ||
if with_nas: | ||
return list(range(1, 9)) + [pd.NA] + list(range(10, 98)) + [pd.NA] + [99, 100] | ||
|
||
return list(range(1, 101)) | ||
|
||
|
||
@pytest.fixture( | ||
|
@@ -52,9 +55,10 @@ def dtype(request): | |
return request.param() | ||
|
||
|
||
@pytest.fixture | ||
def data(dtype): | ||
return pd.array(make_data(), dtype=dtype) | ||
@pytest.fixture(params=[True, False]) | ||
def data(dtype, request): | ||
with_nas = request.param | ||
return pd.array(make_data(with_nas), dtype=dtype) | ||
|
||
|
||
@pytest.fixture | ||
|
@@ -193,7 +197,20 @@ class TestGetitem(base.BaseGetitemTests): | |
|
||
|
||
class TestSetitem(base.BaseSetitemTests): | ||
pass | ||
def test_setitem_series(self, data, full_indexer): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you indicate here why it is overriding the base class? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. comment added |
||
# https://github.com/pandas-dev/pandas/issues/32395 | ||
ser = expected = pd.Series(data, name="data") | ||
result = pd.Series(index=ser.index, dtype=object, name="data") | ||
|
||
key = full_indexer(ser) | ||
result.loc[key] = ser | ||
|
||
if not data._mask.any(): | ||
# GH#38896 like we do with ndarray, we set the values inplace | ||
# but cast to the new numpy dtype | ||
expected = pd.Series(data.to_numpy(data.dtype.numpy_dtype), name="data") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are we converting to the numpy dtype here? That's also not the original dtype? |
||
|
||
self.assert_series_equal(result, expected) | ||
|
||
|
||
class TestMissing(base.BaseMissingTests): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,10 +17,12 @@ | |
import pytest | ||
|
||
from pandas.core.dtypes.dtypes import ExtensionDtype, PandasDtype | ||
from pandas.core.dtypes.missing import infer_fill_value as infer_fill_value_orig | ||
|
||
import pandas as pd | ||
import pandas._testing as tm | ||
from pandas.core.arrays.numpy_ import PandasArray | ||
from pandas.core.arrays import PandasArray, StringArray | ||
from pandas.core.construction import extract_array | ||
|
||
from . import base | ||
|
||
|
@@ -30,6 +32,31 @@ def dtype(request): | |
return PandasDtype(np.dtype(request.param)) | ||
|
||
|
||
orig_setitem = pd.core.internals.Block.setitem | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you use monkeypatch instead? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this does use monkeypatch. the monkeypatched method calls the original method |
||
|
||
|
||
def setitem(self, indexer, value): | ||
# patch Block.setitem | ||
value = extract_array(value, extract_numpy=True) | ||
if isinstance(value, PandasArray) and not isinstance(value, StringArray): | ||
value = value.to_numpy() | ||
if self.ndim == 2 and value.ndim == 1: | ||
# TODO(EA2D): special case not needed with 2D EAs | ||
value = np.atleast_2d(value) | ||
|
||
return orig_setitem(self, indexer, value) | ||
|
||
|
||
def infer_fill_value(val, length: int): | ||
# GH#39044 we have to patch core.dtypes.missing.infer_fill_value | ||
# to unwrap PandasArray bc it won't recognize PandasArray with | ||
# is_extension_dtype | ||
if isinstance(val, PandasArray): | ||
val = val.to_numpy() | ||
|
||
return infer_fill_value_orig(val, length) | ||
|
||
|
||
@pytest.fixture | ||
def allow_in_pandas(monkeypatch): | ||
""" | ||
|
@@ -49,6 +76,8 @@ def allow_in_pandas(monkeypatch): | |
""" | ||
with monkeypatch.context() as m: | ||
m.setattr(PandasArray, "_typ", "extension") | ||
m.setattr(pd.core.indexing, "infer_fill_value", infer_fill_value) | ||
m.setattr(pd.core.internals.Block, "setitem", setitem) | ||
yield | ||
|
||
|
||
|
@@ -458,6 +487,42 @@ def test_setitem_slice(self, data, box_in_series): | |
def test_setitem_loc_iloc_slice(self, data): | ||
super().test_setitem_loc_iloc_slice(data) | ||
|
||
def test_setitem_with_expansion_dataframe_column(self, data, full_indexer, request): | ||
# https://github.com/pandas-dev/pandas/issues/32395 | ||
df = pd.DataFrame({"data": pd.Series(data)}) | ||
result = pd.DataFrame(index=df.index) | ||
|
||
key = full_indexer(df) | ||
result.loc[key, "data"] = df["data"]._values | ||
|
||
expected = pd.DataFrame({"data": data}) | ||
if data.dtype.numpy_dtype != object: | ||
# For PandasArray we expect to get unboxed to numpy | ||
expected = pd.DataFrame({"data": data.to_numpy()}) | ||
|
||
if isinstance(key, slice) and ( | ||
key == slice(None) and data.dtype.numpy_dtype != object | ||
): | ||
mark = pytest.mark.xfail( | ||
reason="This case goes through a different code path" | ||
) | ||
# Other cases go through Block.setitem | ||
request.node.add_marker(mark) | ||
|
||
self.assert_frame_equal(result, expected) | ||
|
||
def test_setitem_series(self, data, full_indexer): | ||
# https://github.com/pandas-dev/pandas/issues/32395 | ||
ser = pd.Series(data, name="data") | ||
result = pd.Series(index=ser.index, dtype=object, name="data") | ||
|
||
key = full_indexer(ser) | ||
result.loc[key] = ser | ||
|
||
# For PandasArray we expect to get unboxed to numpy | ||
expected = pd.Series(data.to_numpy(), name="data") | ||
self.assert_series_equal(result, expected) | ||
|
||
|
||
@skip_nested | ||
class TestParsing(BaseNumPyTests, base.BaseParsingTests): | ||
|
Uh oh!
There was an error while loading. Please reload this page.