Skip to content

Commit 964186b

Browse files
committed
ENH: Add set_index to Series
1 parent 0976e12 commit 964186b

File tree

7 files changed

+799
-467
lines changed

7 files changed

+799
-467
lines changed

doc/source/whatsnew/v0.24.0.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ Other Enhancements
181181
The default compression for ``to_csv``, ``to_json``, and ``to_pickle`` methods has been updated to ``'infer'`` (:issue:`22004`).
182182
- :func:`to_timedelta` now supports iso-formated timedelta strings (:issue:`21877`)
183183
- :class:`Series` and :class:`DataFrame` now support :class:`Iterable` in constructor (:issue:`2193`)
184+
- :class:`Series` has gained the method :meth:`Series.set_index`, which works like its :class:`DataFrame` counterpart :meth:`DataFrame.set_index` (:issue:`21684`)
184185
- :class:`DatetimeIndex` gained :attr:`DatetimeIndex.timetz` attribute. Returns local time with timezone information. (:issue:`21358`)
185186
- :class:`Resampler` now is iterable like :class:`GroupBy` (:issue:`15314`).
186187
- :ref:`Series.resample` and :ref:`DataFrame.resample` have gained the :meth:`Resampler.quantile` (:issue:`15023`).

pandas/core/frame.py

Lines changed: 24 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3843,6 +3843,10 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
38433843
necessary. Setting to False will improve the performance of this
38443844
method
38453845
3846+
Returns
3847+
-------
3848+
reindexed : DataFrame if inplace is False, else None
3849+
38463850
Examples
38473851
--------
38483852
>>> df = pd.DataFrame({'month': [1, 4, 7, 10],
@@ -3883,73 +3887,30 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
38833887
2 2014 4 40
38843888
3 2013 7 84
38853889
4 2014 10 31
3886-
3887-
Returns
3888-
-------
3889-
dataframe : DataFrame
38903890
"""
3891-
inplace = validate_bool_kwarg(inplace, 'inplace')
3891+
from pandas import Series
3892+
38923893
if not isinstance(keys, list):
38933894
keys = [keys]
38943895

3895-
if inplace:
3896-
frame = self
3897-
else:
3898-
frame = self.copy()
3899-
3900-
arrays = []
3901-
names = []
3902-
if append:
3903-
names = [x for x in self.index.names]
3904-
if isinstance(self.index, MultiIndex):
3905-
for i in range(self.index.nlevels):
3906-
arrays.append(self.index._get_level_values(i))
3907-
else:
3908-
arrays.append(self.index)
3909-
3910-
to_remove = []
3911-
for col in keys:
3912-
if isinstance(col, MultiIndex):
3913-
# append all but the last column so we don't have to modify
3914-
# the end of this loop
3915-
for n in range(col.nlevels - 1):
3916-
arrays.append(col._get_level_values(n))
3917-
3918-
level = col._get_level_values(col.nlevels - 1)
3919-
names.extend(col.names)
3920-
elif isinstance(col, Series):
3921-
level = col._values
3922-
names.append(col.name)
3923-
elif isinstance(col, Index):
3924-
level = col
3925-
names.append(col.name)
3926-
elif isinstance(col, (list, np.ndarray, Index)):
3927-
level = col
3928-
names.append(None)
3929-
else:
3930-
level = frame[col]._values
3931-
names.append(col)
3932-
if drop:
3933-
to_remove.append(col)
3934-
arrays.append(level)
3935-
3936-
index = ensure_index_from_sequences(arrays, names)
3937-
3938-
if verify_integrity and not index.is_unique:
3939-
duplicates = index[index.duplicated()].unique()
3940-
raise ValueError('Index has duplicate keys: {dup}'.format(
3941-
dup=duplicates))
3942-
3943-
for c in to_remove:
3944-
del frame[c]
3945-
3946-
# clear up memory usage
3947-
index._cleanup()
3948-
3949-
frame.index = index
3950-
3951-
if not inplace:
3952-
return frame
3896+
# collect elements from "keys" that are not allowed array types
3897+
col_labels = [x for x in keys
3898+
if not isinstance(x, (Series, Index, MultiIndex,
3899+
list, np.ndarray))]
3900+
if any(x not in self for x in col_labels):
3901+
# if there are any labels that are invalid, we raise a KeyError
3902+
missing = [x for x in col_labels if x not in self]
3903+
raise KeyError('{}'.format(missing))
3904+
elif len(set(col_labels)) < len(col_labels):
3905+
# if all are valid labels, but there are duplicates
3906+
dup = Series(col_labels)
3907+
dup = list(dup.loc[dup.duplicated()])
3908+
raise ValueError('Passed duplicate column names '
3909+
'to keys: {dup}'.format(dup=dup))
3910+
vi = verify_integrity
3911+
return super(DataFrame, self).set_index(keys=keys, drop=drop,
3912+
append=append, inplace=inplace,
3913+
verify_integrity=vi)
39533914

39543915
def reset_index(self, level=None, drop=False, inplace=False, col_level=0,
39553916
col_fill=''):

pandas/core/generic.py

Lines changed: 131 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,13 @@
3232
from pandas.core.dtypes.cast import maybe_promote, maybe_upcast_putmask
3333
from pandas.core.dtypes.inference import is_hashable
3434
from pandas.core.dtypes.missing import isna, notna
35-
from pandas.core.dtypes.generic import ABCSeries, ABCPanel, ABCDataFrame
35+
from pandas.core.dtypes.generic import (ABCIndexClass, ABCMultiIndex, ABCPanel,
36+
ABCSeries, ABCDataFrame)
3637

3738
from pandas.core.base import PandasObject, SelectionMixin
38-
from pandas.core.index import (Index, MultiIndex, ensure_index,
39-
InvalidIndexError, RangeIndex)
39+
from pandas.core.index import (Index, MultiIndex,
40+
InvalidIndexError, RangeIndex,
41+
ensure_index, ensure_index_from_sequences)
4042
import pandas.core.indexing as indexing
4143
from pandas.core.indexes.datetimes import DatetimeIndex
4244
from pandas.core.indexes.period import PeriodIndex, Period
@@ -663,6 +665,132 @@ def _set_axis(self, axis, labels):
663665
y : same as input
664666
"""
665667

668+
def set_index(self, keys, drop=True, append=False, inplace=False,
669+
verify_integrity=False):
670+
"""
671+
Set the Series/DataFrame index (row labels) using one or more given
672+
arrays (or column labels in case of DataFrame).
673+
By default yields a new object.
674+
675+
Parameters
676+
----------
677+
keys : column label or list of column labels / arrays. For Series case,
678+
only array or list of arrays is allowed.
679+
drop : boolean, default True
680+
Delete columns to be used as the new index (only for DataFrame).
681+
append : boolean, default False
682+
Whether to append columns to existing index
683+
inplace : boolean, default False
684+
Modify the Series/DataFrame in place (do not create a new object)
685+
verify_integrity : boolean, default False
686+
Check the new index for duplicates. Otherwise defer the check until
687+
necessary. Setting to False will improve the performance of this
688+
method
689+
690+
Returns
691+
-------
692+
reindexed : Series/DataFrame if inplace is False, else None
693+
694+
Examples
695+
--------
696+
>>> df = pd.DataFrame({'month': [1, 4, 7, 10],
697+
... 'year': [2012, 2014, 2013, 2014],
698+
... 'sale':[55, 40, 84, 31]})
699+
month sale year
700+
0 1 55 2012
701+
1 4 40 2014
702+
2 7 84 2013
703+
3 10 31 2014
704+
705+
Set the index to become the 'month' column:
706+
707+
>>> df.set_index('month')
708+
sale year
709+
month
710+
1 55 2012
711+
4 40 2014
712+
7 84 2013
713+
10 31 2014
714+
715+
Create a multi-index using columns 'year' and 'month':
716+
717+
>>> df.set_index(['year', 'month'])
718+
sale
719+
year month
720+
2012 1 55
721+
2014 4 40
722+
2013 7 84
723+
2014 10 31
724+
725+
Create a multi-index using a set of values and a column:
726+
727+
>>> df.set_index([[1, 2, 3, 4], 'year'])
728+
month sale
729+
year
730+
1 2012 1 55
731+
2 2014 4 40
732+
3 2013 7 84
733+
4 2014 10 31
734+
"""
735+
inplace = validate_bool_kwarg(inplace, 'inplace')
736+
if inplace:
737+
obj = self
738+
else:
739+
obj = self.copy()
740+
741+
arrays = []
742+
names = []
743+
if append:
744+
names = [x for x in self.index.names]
745+
if isinstance(self.index, ABCMultiIndex):
746+
for i in range(self.index.nlevels):
747+
arrays.append(self.index._get_level_values(i))
748+
else:
749+
arrays.append(self.index)
750+
751+
to_remove = []
752+
for col in keys:
753+
if isinstance(col, ABCMultiIndex):
754+
for n in range(col.nlevels):
755+
arrays.append(col._get_level_values(n))
756+
names.extend(col.names)
757+
elif isinstance(col, ABCIndexClass):
758+
# Index but not MultiIndex (treated above)
759+
arrays.append(col)
760+
names.append(col.name)
761+
elif isinstance(col, ABCSeries):
762+
arrays.append(col._values)
763+
names.append(col.name)
764+
elif isinstance(col, (list, np.ndarray)):
765+
arrays.append(col)
766+
names.append(None)
767+
# from here, col can only be a column label (and self a DataFrame);
768+
# see checks in Series.set_index and DataFrame.set_index
769+
else:
770+
arrays.append(obj[col]._values)
771+
names.append(col)
772+
if drop:
773+
to_remove.append(col)
774+
775+
index = ensure_index_from_sequences(arrays, names)
776+
777+
if verify_integrity and not index.is_unique:
778+
duplicates = list(index[index.duplicated()])
779+
raise ValueError('Index has duplicate keys: {dup}'.format(
780+
dup=duplicates))
781+
782+
# use set to handle duplicate column names gracefully in case of drop
783+
for c in set(to_remove):
784+
del obj[c]
785+
786+
# clear up memory usage
787+
index._cleanup()
788+
789+
obj.index = index
790+
791+
if not inplace:
792+
return obj
793+
666794
@Appender(_shared_docs['transpose'] % _shared_doc_kwargs)
667795
def transpose(self, *args, **kwargs):
668796

pandas/core/series.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
_is_unorderable_exception,
3636
ensure_platform_int,
3737
pandas_dtype)
38-
from pandas.core.dtypes.generic import (
39-
ABCSparseArray, ABCDataFrame, ABCIndexClass)
38+
from pandas.core.dtypes.generic import (ABCDataFrame, ABCIndexClass,
39+
ABCSeries, ABCSparseArray)
4040
from pandas.core.dtypes.cast import (
4141
maybe_upcast, infer_dtype_from_scalar,
4242
maybe_convert_platform,
@@ -1094,6 +1094,86 @@ def _set_value(self, label, value, takeable=False):
10941094
return self
10951095
_set_value.__doc__ = set_value.__doc__
10961096

1097+
def set_index(self, arrays, append=False, inplace=False,
1098+
verify_integrity=False):
1099+
"""
1100+
Set the Series index (row labels) using one or more columns.
1101+
By default yields a new object.
1102+
1103+
Parameters
1104+
----------
1105+
arrays : array or list of arrays
1106+
Either a Series, Index, MultiIndex, list, np.ndarray or a list
1107+
containing only Series, Index, MultiIndex, list, np.ndarray
1108+
append : boolean, default False
1109+
Whether to append columns to existing index
1110+
inplace : boolean, default False
1111+
Modify the Series in place (do not create a new object)
1112+
verify_integrity : boolean, default False
1113+
Check the new index for duplicates. Otherwise defer the check until
1114+
necessary. Setting to False will improve the performance of this
1115+
method
1116+
1117+
Returns
1118+
-------
1119+
reindexed : Series if inplace is False, else None
1120+
1121+
Examples
1122+
--------
1123+
>>> s = pd.Series(range(3))
1124+
0 10
1125+
1 11
1126+
2 12
1127+
dtype: int64
1128+
1129+
Set the index to become `['a', 'b', 'c']`:
1130+
1131+
>>> s.set_index(['a', 'b', 'c'])
1132+
a 10
1133+
b 11
1134+
c 12
1135+
dtype: int64
1136+
1137+
Create a multi-index by appending to the existing index:
1138+
1139+
>>> s.set_index(['a', 'b', 'c'], append=True)
1140+
0 a 10
1141+
1 b 11
1142+
2 c 12
1143+
dtype: int64
1144+
1145+
Create a multi-index by passing a list of arrays:
1146+
1147+
>>> t = s.set_index([['a', 'b', 'c'], ['I', 'II', 'III']]) ** 2
1148+
>>> t
1149+
a I 100
1150+
b II 121
1151+
c III 144
1152+
dtype: int64
1153+
1154+
Apply index from another object (of the same length!):
1155+
1156+
>>> s.set_index(t.index)
1157+
a I 10
1158+
b II 11
1159+
c III 12
1160+
dtype: int64
1161+
"""
1162+
if not isinstance(arrays, list):
1163+
arrays = [arrays]
1164+
elif all(is_scalar(x) for x in arrays):
1165+
arrays = [arrays]
1166+
1167+
if any(not isinstance(x, (ABCSeries, ABCIndexClass, list, np.ndarray))
1168+
for x in arrays):
1169+
raise TypeError('arrays must be Series, Index, MultiIndex, list, '
1170+
'np.ndarray or list containing only Series, '
1171+
'Index, MultiIndex, list, np.ndarray')
1172+
1173+
return super(Series, self).set_index(keys=arrays, drop=False,
1174+
append=append, inplace=inplace,
1175+
verify_integrity=verify_integrity)
1176+
10971177
def reset_index(self, level=None, drop=False, name=None, inplace=False):
10981178
"""
10991179
Generate a new DataFrame or Series with the index reset.

pandas/tests/frame/common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,15 @@ def simple(self):
103103
return pd.DataFrame(arr, columns=['one', 'two', 'three'],
104104
index=['a', 'b', 'c'])
105105

106+
@cache_readonly
107+
def dummy(self):
108+
df = pd.DataFrame({'A': ['foo', 'foo', 'foo', 'bar', 'bar'],
109+
'B': ['one', 'two', 'three', 'one', 'two'],
110+
'C': ['a', 'b', 'c', 'd', 'e'],
111+
'D': np.random.randn(5),
112+
'E': np.random.randn(5)})
113+
return df
114+
106115
# self.ts3 = tm.makeTimeSeries()[-5:]
107116
# self.ts4 = tm.makeTimeSeries()[1:-1]
108117

0 commit comments

Comments
 (0)