From 6969c9555221b71f9316fd4ecbdc2a410237f96f Mon Sep 17 00:00:00 2001 From: sinhrks Date: Sat, 27 Feb 2016 14:03:55 +0900 Subject: [PATCH] CLN: cleanup _wrap_result --- pandas/core/strings.py | 88 +++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 49 deletions(-) diff --git a/pandas/core/strings.py b/pandas/core/strings.py index c1ab46956c25f..a7ed1ba0c0be0 100644 --- a/pandas/core/strings.py +++ b/pandas/core/strings.py @@ -604,7 +604,7 @@ def str_extract(arr, pat, flags=0, expand=None): return _str_extract_frame(arr._orig, pat, flags=flags) else: result, name = _str_extract_noexpand(arr._data, pat, flags=flags) - return arr._wrap_result(result, name=name) + return arr._wrap_result(result, name=name, expand=expand) def str_extractall(arr, pat, flags=0): @@ -1292,7 +1292,10 @@ def __iter__(self): i += 1 g = self.get(i) - def _wrap_result(self, result, use_codes=True, name=None): + def _wrap_result(self, result, use_codes=True, + name=None, expand=None): + + from pandas.core.index import Index, MultiIndex # for category, we do the stuff on the categories, so blow it up # to the full series again @@ -1302,48 +1305,42 @@ def _wrap_result(self, result, use_codes=True, name=None): if use_codes and self._is_categorical: result = take_1d(result, self._orig.cat.codes) - # leave as it is to keep extract and get_dummies results - # can be merged to _wrap_result_expand in v0.17 - from pandas.core.series import Series - from pandas.core.frame import DataFrame - from pandas.core.index import Index - - if not hasattr(result, 'ndim'): + if not hasattr(result, 'ndim') or not hasattr(result, 'dtype'): return result + assert result.ndim < 3 - if result.ndim == 1: - # Wait until we are sure result is a Series or Index before - # checking attributes (GH 12180) - name = name or getattr(result, 'name', None) or self._orig.name - if isinstance(self._orig, Index): - # if result is a boolean np.array, return the np.array - # instead of wrapping it into a boolean Index (GH 8875) - if is_bool_dtype(result): - return result - return Index(result, name=name) - return Series(result, index=self._orig.index, name=name) - else: - assert result.ndim < 3 - return DataFrame(result, index=self._orig.index) + if expand is None: + # infer from ndim if expand is not specified + expand = False if result.ndim == 1 else True + + elif expand is True and not isinstance(self._orig, Index): + # required when expand=True is explicitly specified + # not needed when infered + + def cons_row(x): + if is_list_like(x): + return x + else: + return [x] + + result = [cons_row(x) for x in result] - def _wrap_result_expand(self, result, expand=False): if not isinstance(expand, bool): raise ValueError("expand must be True or False") - # for category, we do the stuff on the categories, so blow it up - # to the full series again - if self._is_categorical: - result = take_1d(result, self._orig.cat.codes) - - from pandas.core.index import Index, MultiIndex - if not hasattr(result, 'ndim'): - return result + if name is None: + name = getattr(result, 'name', None) + if name is None: + # do not use logical or, _orig may be a DataFrame + # which has "name" column + name = self._orig.name + # Wait until we are sure result is a Series or Index before + # checking attributes (GH 12180) if isinstance(self._orig, Index): - name = getattr(result, 'name', None) # if result is a boolean np.array, return the np.array # instead of wrapping it into a boolean Index (GH 8875) - if hasattr(result, 'dtype') and is_bool_dtype(result): + if is_bool_dtype(result): return result if expand: @@ -1354,18 +1351,10 @@ def _wrap_result_expand(self, result, expand=False): else: index = self._orig.index if expand: - - def cons_row(x): - if is_list_like(x): - return x - else: - return [x] - cons = self._orig._constructor_expanddim - data = [cons_row(x) for x in result] - return cons(data, index=index) + return cons(result, index=index) else: - name = getattr(result, 'name', None) + # Must a Series cons = self._orig._constructor return cons(result, name=name, index=index) @@ -1380,12 +1369,12 @@ def cat(self, others=None, sep=None, na_rep=None): @copy(str_split) def split(self, pat=None, n=-1, expand=False): result = str_split(self._data, pat, n=n) - return self._wrap_result_expand(result, expand=expand) + return self._wrap_result(result, expand=expand) @copy(str_rsplit) def rsplit(self, pat=None, n=-1, expand=False): result = str_rsplit(self._data, pat, n=n) - return self._wrap_result_expand(result, expand=expand) + return self._wrap_result(result, expand=expand) _shared_docs['str_partition'] = (""" Split the string at the %(side)s occurrence of `sep`, and return 3 elements @@ -1440,7 +1429,7 @@ def rsplit(self, pat=None, n=-1, expand=False): def partition(self, pat=' ', expand=True): f = lambda x: x.partition(pat) result = _na_map(f, self._data) - return self._wrap_result_expand(result, expand=expand) + return self._wrap_result(result, expand=expand) @Appender(_shared_docs['str_partition'] % { 'side': 'last', @@ -1451,7 +1440,7 @@ def partition(self, pat=' ', expand=True): def rpartition(self, pat=' ', expand=True): f = lambda x: x.rpartition(pat) result = _na_map(f, self._data) - return self._wrap_result_expand(result, expand=expand) + return self._wrap_result(result, expand=expand) @copy(str_get) def get(self, i): @@ -1597,7 +1586,8 @@ def get_dummies(self, sep='|'): # methods available for making the dummies... data = self._orig.astype(str) if self._is_categorical else self._data result = str_get_dummies(data, sep) - return self._wrap_result(result, use_codes=(not self._is_categorical)) + return self._wrap_result(result, use_codes=(not self._is_categorical), + expand=True) @copy(str_translate) def translate(self, table, deletechars=None):