Skip to content

Commit f24f2e8

Browse files
wabuhayd
authored and committed
BUG StringMethods on empty series (GH7242)
- all StringMethods are tested and work on empty series; moreover, extract always returns dtype==object, even when no match is found
1 parent a477a2e commit f24f2e8

File tree

4 files changed

+78
-17
lines changed

4 files changed

+78
-17
lines changed

doc/source/basics.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,6 +1165,9 @@ Thus, a Series of messy strings can be "converted" into a
11651165
like-indexed Series or DataFrame of cleaned-up or more useful strings,
11661166
without necessitating ``get()`` to access tuples or ``re.match`` objects.
11671167

1168+
The result's dtype is always object, even if no match is found and the result
1169+
only contains ``NaN``.
1170+
11681171
Named groups like
11691172

11701173
.. ipython:: python

doc/source/v0.14.1.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ API changes
3030
- Openpyxl now raises a ValueError on construction of the openpyxl writer
3131
instead of warning on pandas import (:issue:`7284`).
3232

33+
- For ``StringMethods.extract``, when no match is found, the result - only
34+
containing ``NaN`` values - now also has ``dtype=object`` instead of
35+
``float`` (:issue:`7242`)
36+
3337
.. _whatsnew_0141.prior_deprecations:
3438

3539
Prior Version Deprecations/Changes
@@ -90,3 +94,4 @@ Bug Fixes
9094
- Bug in broadcasting with ``.div``, integer dtypes and divide-by-zero (:issue:`7325`)
9195
- Bug in ``CustomBusinessDay.apply`` raises ``NameError`` when ``np.datetime64`` object is passed (:issue:`7196`)
9296
- Bug in ``MultiIndex.append``, ``concat`` and ``pivot_table`` don't preserve timezone (:issue:`6606`)
97+
- Bug in ``StringMethods`` on empty Series; all methods now work on empty Series (:issue:`7242`)

pandas/core/strings.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
def _get_array_list(arr, others):
15-
if isinstance(others[0], (list, np.ndarray)):
15+
if len(others) and isinstance(others[0], (list, np.ndarray)):
1616
arrays = [arr] + list(others)
1717
else:
1818
arrays = [arr, others]
@@ -88,12 +88,15 @@ def _length_check(others):
8888
return n
8989

9090

91-
def _na_map(f, arr, na_result=np.nan):
91+
def _na_map(f, arr, na_result=np.nan, dtype=object):
9292
# should really _check_ for NA
93-
return _map(f, arr, na_mask=True, na_value=na_result)
93+
return _map(f, arr, na_mask=True, na_value=na_result, dtype=dtype)
9494

9595

96-
def _map(f, arr, na_mask=False, na_value=np.nan):
96+
def _map(f, arr, na_mask=False, na_value=np.nan, dtype=object):
97+
if not len(arr):
98+
return np.ndarray(0, dtype=dtype)
99+
97100
if isinstance(arr, Series):
98101
arr = arr.values
99102
if not isinstance(arr, np.ndarray):
@@ -108,7 +111,7 @@ def g(x):
108111
return f(x)
109112
except (TypeError, AttributeError):
110113
return na_value
111-
return _map(g, arr)
114+
return _map(g, arr, dtype=dtype)
112115
if na_value is not np.nan:
113116
np.putmask(result, mask, na_value)
114117
if result.dtype == object:
@@ -146,7 +149,7 @@ def str_count(arr, pat, flags=0):
146149
"""
147150
regex = re.compile(pat, flags=flags)
148151
f = lambda x: len(regex.findall(x))
149-
return _na_map(f, arr)
152+
return _na_map(f, arr, dtype=int)
150153

151154

152155
def str_contains(arr, pat, case=True, flags=0, na=np.nan, regex=True):
@@ -187,7 +190,7 @@ def str_contains(arr, pat, case=True, flags=0, na=np.nan, regex=True):
187190
f = lambda x: bool(regex.search(x))
188191
else:
189192
f = lambda x: pat in x
190-
return _na_map(f, arr, na)
193+
return _na_map(f, arr, na, dtype=bool)
191194

192195

193196
def str_startswith(arr, pat, na=np.nan):
@@ -206,7 +209,7 @@ def str_startswith(arr, pat, na=np.nan):
206209
startswith : array (boolean)
207210
"""
208211
f = lambda x: x.startswith(pat)
209-
return _na_map(f, arr, na)
212+
return _na_map(f, arr, na, dtype=bool)
210213

211214

212215
def str_endswith(arr, pat, na=np.nan):
@@ -225,7 +228,7 @@ def str_endswith(arr, pat, na=np.nan):
225228
endswith : array (boolean)
226229
"""
227230
f = lambda x: x.endswith(pat)
228-
return _na_map(f, arr, na)
231+
return _na_map(f, arr, na, dtype=bool)
229232

230233

231234
def str_lower(arr):
@@ -375,6 +378,7 @@ def str_match(arr, pat, case=True, flags=0, na=np.nan, as_indexer=False):
375378
# and is basically useless, so we will not warn.
376379

377380
if (not as_indexer) and regex.groups > 0:
381+
dtype = object
378382
def f(x):
379383
m = regex.match(x)
380384
if m:
@@ -383,9 +387,10 @@ def f(x):
383387
return []
384388
else:
385389
# This is the new behavior of str_match.
390+
dtype = bool
386391
f = lambda x: bool(regex.match(x))
387392

388-
return _na_map(f, arr, na)
393+
return _na_map(f, arr, na, dtype=dtype)
389394

390395

391396
def _get_single_group_name(rx):
@@ -409,6 +414,9 @@ def str_extract(arr, pat, flags=0):
409414
Returns
410415
-------
411416
extracted groups : Series (one group) or DataFrame (multiple groups)
417+
Note that dtype of the result is always object, even when no match is
418+
found and the result is a Series or DataFrame containing only NaN
419+
values.
412420
413421
Examples
414422
--------
@@ -461,13 +469,17 @@ def f(x):
461469
if regex.groups == 1:
462470
result = Series([f(val)[0] for val in arr],
463471
name=_get_single_group_name(regex),
464-
index=arr.index)
472+
index=arr.index, dtype=object)
465473
else:
466474
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
467475
columns = [names.get(1 + i, i) for i in range(regex.groups)]
468-
result = DataFrame([f(val) for val in arr],
469-
columns=columns,
470-
index=arr.index)
476+
if arr.empty:
477+
result = DataFrame(columns=columns, dtype=object)
478+
else:
479+
result = DataFrame([f(val) for val in arr],
480+
columns=columns,
481+
index=arr.index,
482+
dtype=object)
471483
return result
472484

473485

@@ -536,7 +548,7 @@ def str_len(arr):
536548
-------
537549
lengths : array
538550
"""
539-
return _na_map(len, arr)
551+
return _na_map(len, arr, dtype=int)
540552

541553

542554
def str_findall(arr, pat, flags=0):

pandas/tests/test_strings.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,12 +505,12 @@ def test_extract(self):
505505

506506
# one group, no matches
507507
result = s.str.extract('(_)')
508-
exp = Series([NA, NA, NA])
508+
exp = Series([NA, NA, NA], dtype=object)
509509
tm.assert_series_equal(result, exp)
510510

511511
# two groups, no matches
512512
result = s.str.extract('(_)(_)')
513-
exp = DataFrame([[NA, NA], [NA, NA], [NA, NA]])
513+
exp = DataFrame([[NA, NA], [NA, NA], [NA, NA]], dtype=object)
514514
tm.assert_frame_equal(result, exp)
515515

516516
# one group, some matches
@@ -585,6 +585,47 @@ def test_extract_single_series_name_is_preserved(self):
585585
tm.assert_series_equal(r, e)
586586
self.assertEqual(r.name, e.name)
587587

588+
def test_empty_str_methods(self):
589+
empty_str = empty = Series(dtype=str)
590+
empty_int = Series(dtype=int)
591+
empty_bool = Series(dtype=bool)
592+
empty_list = Series(dtype=list)
593+
empty_bytes = Series(dtype=object)
594+
595+
# GH7241
596+
# (extract) on empty series
597+
598+
tm.assert_series_equal(empty_str, empty.str.cat(empty))
599+
tm.assert_equal('', empty.str.cat())
600+
tm.assert_series_equal(empty_str, empty.str.title())
601+
tm.assert_series_equal(empty_int, empty.str.count('a'))
602+
tm.assert_series_equal(empty_bool, empty.str.contains('a'))
603+
tm.assert_series_equal(empty_bool, empty.str.startswith('a'))
604+
tm.assert_series_equal(empty_bool, empty.str.endswith('a'))
605+
tm.assert_series_equal(empty_str, empty.str.lower())
606+
tm.assert_series_equal(empty_str, empty.str.upper())
607+
tm.assert_series_equal(empty_str, empty.str.replace('a','b'))
608+
tm.assert_series_equal(empty_str, empty.str.repeat(3))
609+
tm.assert_series_equal(empty_bool, empty.str.match('^a'))
610+
tm.assert_series_equal(empty_str, empty.str.extract('()'))
611+
tm.assert_frame_equal(DataFrame(columns=[0,1], dtype=str), empty.str.extract('()()'))
612+
tm.assert_frame_equal(DataFrame(dtype=str), empty.str.get_dummies())
613+
tm.assert_series_equal(empty_str, empty_list.str.join(''))
614+
tm.assert_series_equal(empty_int, empty.str.len())
615+
tm.assert_series_equal(empty_list, empty_list.str.findall('a'))
616+
tm.assert_series_equal(empty_str, empty.str.pad(42))
617+
tm.assert_series_equal(empty_str, empty.str.center(42))
618+
tm.assert_series_equal(empty_list, empty.str.split('a'))
619+
tm.assert_series_equal(empty_str, empty.str.slice(stop=1))
620+
tm.assert_series_equal(empty_str, empty.str.strip())
621+
tm.assert_series_equal(empty_str, empty.str.lstrip())
622+
tm.assert_series_equal(empty_str, empty.str.rstrip())
623+
tm.assert_series_equal(empty_str, empty.str.rstrip())
624+
tm.assert_series_equal(empty_str, empty.str.wrap(42))
625+
tm.assert_series_equal(empty_str, empty.str.get(0))
626+
tm.assert_series_equal(empty_str, empty_bytes.str.decode('ascii'))
627+
tm.assert_series_equal(empty_bytes, empty.str.encode('ascii'))
628+
588629
def test_get_dummies(self):
589630
s = Series(['a|b', 'a|c', np.nan])
590631
result = s.str.get_dummies('|')

0 commit comments

Comments
 (0)