Skip to content

Commit 62ddb7e

Browse files
committed
BUG: Accept dict or Series in fillna for categorical Series
1 parent c2590b3 commit 62ddb7e

File tree

2 files changed

+58
-9
lines changed

2 files changed

+58
-9
lines changed

pandas/core/categorical.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,16 +1665,26 @@ def fillna(self, value=None, method=None, limit=None):
16651665

16661666
else:
16671667

1668-
if not isna(value) and value not in self.categories:
1669-
raise ValueError("fill value must be in categories")
1668+
if isinstance(value, ABCSeries):
1669+
if not value[~value.isin(self.categories)].isna().all():
1670+
raise ValueError("fill value must be in categories")
16701671

1671-
mask = values == -1
1672-
if mask.any():
1673-
values = values.copy()
1674-
if isna(value):
1675-
values[mask] = -1
1676-
else:
1677-
values[mask] = self.categories.get_loc(value)
1672+
values_codes = _get_codes_for_values(value, self.categories)
1673+
indexer = np.where(values_codes != -1)
1674+
values[indexer] = values_codes[values_codes != -1]
1675+
1676+
# Scalar value
1677+
else:
1678+
if not isna(value) and value not in self.categories:
1679+
raise ValueError("fill value must be in categories")
1680+
1681+
mask = values == -1
1682+
if mask.any():
1683+
values = values.copy()
1684+
if isna(value):
1685+
values[mask] = -1
1686+
else:
1687+
values[mask] = self.categories.get_loc(value)
16781688

16791689
return self._constructor(values, categories=self.categories,
16801690
ordered=self.ordered, fastpath=True)

pandas/tests/test_categorical.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4569,6 +4569,45 @@ def f():
45694569
df = DataFrame({'a': Categorical(idx)})
45704570
tm.assert_frame_equal(df.fillna(value=NaT), df)
45714571

4572+
@pytest.mark.parametrize('fill_value expected_output', [
4573+
('a', ['a', 'a', 'b', 'a', 'a']),
4574+
({1: 'a', 3: 'b', 4: 'b'}, ['a', 'a', 'b', 'b', 'b']),
4575+
({1: 'a'}, ['a', 'a', 'b', np.nan, np.nan]),
4576+
({1: 'a', 3: 'b'}, ['a', 'a', 'b', 'b', np.nan]),
4577+
(pd.Series('a'), ['a', np.nan, 'b', np.nan, np.nan]),
4578+
(pd.Series('a', index=[1]), ['a', 'a', 'b', np.nan, np.nan]),
4579+
(pd.Series({1: 'a', 3: 'b'}), ['a', 'a', 'b', 'b', np.nan]),
4580+
(pd.Series(['a', 'b'], index=[3, 4]))
4581+
])
4582+
def fillna_series_categorical(self, fill_value, expected_output):
4583+
# GH 17033
4584+
# Test fillna for a Categorical series
4585+
data = ['a', np.nan, 'b', np.nan, np.nan]
4586+
s = pd.Series(pd.Categorical(data, categories=['a', 'b']))
4587+
exp = pd.Series(pd.Categorical(expected_output, categories=['a', 'b']))
4588+
tm.assert_series_equal(s.fillna(fill_value), exp)
4589+
4590+
def fillna_series_categorical_errormsg(self):
4591+
data = ['a', np.nan, 'b', np.nan, np.nan]
4592+
s = pd.Series(pd.Categorical(data, categories=['a', 'b']))
4593+
4594+
with tm.assert_raises_regex(ValueError,
4595+
"fill value must be in categories"):
4596+
s.fillna('d')
4597+
4598+
with tm.assert_raises_regex(ValueError,
4599+
"fill value must be in categories"):
4600+
s.fillna(pd.Series('d'))
4601+
4602+
with tm.assert_raises_regex(ValueError,
4603+
"fill value must be in categories"):
4604+
s.fillna({1: 'd', 3: 'a'})
4605+
4606+
with tm.assert_raises_regex(TypeError,
4607+
'"value" parameter must be a scalar or '
4608+
'dict but you passed a "list"'):
4609+
s.fillna(['a', 'b'])
4610+
45724611
def test_astype_to_other(self):
45734612

45744613
s = self.cat['value_group']

0 commit comments

Comments
 (0)