12
12
13
13
14
14
def _get_array_list (arr , others ):
15
- if isinstance (others [0 ], (list , np .ndarray )):
15
+ if len ( others ) and isinstance (others [0 ], (list , np .ndarray )):
16
16
arrays = [arr ] + list (others )
17
17
else :
18
18
arrays = [arr , others ]
@@ -88,12 +88,15 @@ def _length_check(others):
88
88
return n
89
89
90
90
91
- def _na_map (f , arr , na_result = np .nan ):
91
+ def _na_map (f , arr , na_result = np .nan , dtype = object ):
92
92
# should really _check_ for NA
93
- return _map (f , arr , na_mask = True , na_value = na_result )
93
+ return _map (f , arr , na_mask = True , na_value = na_result , dtype = dtype )
94
94
95
95
96
- def _map (f , arr , na_mask = False , na_value = np .nan ):
96
+ def _map (f , arr , na_mask = False , na_value = np .nan , dtype = object ):
97
+ if not len (arr ):
98
+ return np .ndarray (0 , dtype = dtype )
99
+
97
100
if isinstance (arr , Series ):
98
101
arr = arr .values
99
102
if not isinstance (arr , np .ndarray ):
@@ -108,7 +111,7 @@ def g(x):
108
111
return f (x )
109
112
except (TypeError , AttributeError ):
110
113
return na_value
111
- return _map (g , arr )
114
+ return _map (g , arr , dtype = dtype )
112
115
if na_value is not np .nan :
113
116
np .putmask (result , mask , na_value )
114
117
if result .dtype == object :
@@ -146,7 +149,7 @@ def str_count(arr, pat, flags=0):
146
149
"""
147
150
regex = re .compile (pat , flags = flags )
148
151
f = lambda x : len (regex .findall (x ))
149
- return _na_map (f , arr )
152
+ return _na_map (f , arr , dtype = int )
150
153
151
154
152
155
def str_contains (arr , pat , case = True , flags = 0 , na = np .nan , regex = True ):
@@ -187,7 +190,7 @@ def str_contains(arr, pat, case=True, flags=0, na=np.nan, regex=True):
187
190
f = lambda x : bool (regex .search (x ))
188
191
else :
189
192
f = lambda x : pat in x
190
- return _na_map (f , arr , na )
193
+ return _na_map (f , arr , na , dtype = bool )
191
194
192
195
193
196
def str_startswith (arr , pat , na = np .nan ):
@@ -206,7 +209,7 @@ def str_startswith(arr, pat, na=np.nan):
206
209
startswith : array (boolean)
207
210
"""
208
211
f = lambda x : x .startswith (pat )
209
- return _na_map (f , arr , na )
212
+ return _na_map (f , arr , na , dtype = bool )
210
213
211
214
212
215
def str_endswith (arr , pat , na = np .nan ):
@@ -225,7 +228,7 @@ def str_endswith(arr, pat, na=np.nan):
225
228
endswith : array (boolean)
226
229
"""
227
230
f = lambda x : x .endswith (pat )
228
- return _na_map (f , arr , na )
231
+ return _na_map (f , arr , na , dtype = bool )
229
232
230
233
231
234
def str_lower (arr ):
@@ -375,6 +378,7 @@ def str_match(arr, pat, case=True, flags=0, na=np.nan, as_indexer=False):
375
378
# and is basically useless, so we will not warn.
376
379
377
380
if (not as_indexer ) and regex .groups > 0 :
381
+ dtype = object
378
382
def f (x ):
379
383
m = regex .match (x )
380
384
if m :
@@ -383,9 +387,10 @@ def f(x):
383
387
return []
384
388
else :
385
389
# This is the new behavior of str_match.
390
+ dtype = bool
386
391
f = lambda x : bool (regex .match (x ))
387
392
388
- return _na_map (f , arr , na )
393
+ return _na_map (f , arr , na , dtype = dtype )
389
394
390
395
391
396
def _get_single_group_name (rx ):
@@ -409,6 +414,9 @@ def str_extract(arr, pat, flags=0):
409
414
Returns
410
415
-------
411
416
extracted groups : Series (one group) or DataFrame (multiple groups)
417
+ Note that dtype of the result is always object, even when no match is
418
+ found and the result is a Series or DataFrame containing only NaN
419
+ values.
412
420
413
421
Examples
414
422
--------
@@ -461,13 +469,17 @@ def f(x):
461
469
if regex .groups == 1 :
462
470
result = Series ([f (val )[0 ] for val in arr ],
463
471
name = _get_single_group_name (regex ),
464
- index = arr .index )
472
+ index = arr .index , dtype = object )
465
473
else :
466
474
names = dict (zip (regex .groupindex .values (), regex .groupindex .keys ()))
467
475
columns = [names .get (1 + i , i ) for i in range (regex .groups )]
468
- result = DataFrame ([f (val ) for val in arr ],
469
- columns = columns ,
470
- index = arr .index )
476
+ if arr .empty :
477
+ result = DataFrame (columns = columns , dtype = object )
478
+ else :
479
+ result = DataFrame ([f (val ) for val in arr ],
480
+ columns = columns ,
481
+ index = arr .index ,
482
+ dtype = object )
471
483
return result
472
484
473
485
@@ -536,7 +548,7 @@ def str_len(arr):
536
548
-------
537
549
lengths : array
538
550
"""
539
- return _na_map (len , arr )
551
+ return _na_map (len , arr , dtype = int )
540
552
541
553
542
554
def str_findall (arr , pat , flags = 0 ):
0 commit comments