diff --git a/doc/source/whatsnew/v0.17.0.txt b/doc/source/whatsnew/v0.17.0.txt index dae1342c3cd76..bebce2d3e2d87 100644 --- a/doc/source/whatsnew/v0.17.0.txt +++ b/doc/source/whatsnew/v0.17.0.txt @@ -67,6 +67,7 @@ Bug Fixes - Bug in ``mean()`` where integer dtypes can overflow (:issue:`10172`) - Bug where Panel.from_dict does not set dtype when specified (:issue:`10058`) +- Bug in ``Index.union`` raises ``AttributeError`` when passing array-likes. (:issue:`10149`) - Bug in ``Timestamp``'s' ``microsecond``, ``quarter``, ``dayofyear``, ``week`` and ``daysinmonth`` properties return ``np.int`` type, not built-in ``int``. (:issue:`10050`) - Bug in ``NaT`` raises ``AttributeError`` when accessing to ``daysinmonth``, ``dayofweek`` properties. (:issue:`10096`) @@ -91,3 +92,4 @@ Bug Fixes - Bug where infer_freq infers timerule (WOM-5XXX) unsupported by to_offset (:issue:`9425`) + diff --git a/pandas/core/index.py b/pandas/core/index.py index de30fee4009f4..2bd96fcec2e42 100644 --- a/pandas/core/index.py +++ b/pandas/core/index.py @@ -580,8 +580,18 @@ def to_datetime(self, dayfirst=False): return DatetimeIndex(self.values) def _assert_can_do_setop(self, other): + if not com.is_list_like(other): + raise TypeError('Input must be Index or array-like') return True + def _convert_can_do_setop(self, other): + if not isinstance(other, Index): + other = Index(other, name=self.name) + result_name = self.name + else: + result_name = self.name if self.name == other.name else None + return other, result_name + @property def nlevels(self): return 1 @@ -1364,16 +1374,14 @@ def union(self, other): ------- union : Index """ - if not hasattr(other, '__iter__'): - raise TypeError('Input must be iterable.') + self._assert_can_do_setop(other) + other = _ensure_index(other) if len(other) == 0 or self.equals(other): return self if len(self) == 0: - return _ensure_index(other) - - self._assert_can_do_setop(other) + return other if not is_dtype_equal(self.dtype,other.dtype): this = self.astype('O') @@ -1439,11 +1447,7 @@ def intersection(self, other): ------- intersection : Index """ - if not hasattr(other, '__iter__'): - raise TypeError('Input must be iterable!') - self._assert_can_do_setop(other) - other = _ensure_index(other) if self.equals(other): @@ -1492,18 +1496,12 @@ def difference(self, other): >>> index.difference(index2) """ - - if not hasattr(other, '__iter__'): - raise TypeError('Input must be iterable!') + self._assert_can_do_setop(other) if self.equals(other): return Index([], name=self.name) - if not isinstance(other, Index): - other = np.asarray(other) - result_name = self.name - else: - result_name = self.name if self.name == other.name else None + other, result_name = self._convert_can_do_setop(other) theDiff = sorted(set(self) - set(other)) return Index(theDiff, name=result_name) @@ -1517,7 +1515,7 @@ def sym_diff(self, other, result_name=None): Parameters ---------- - other : array-like + other : Index or array-like result_name : str Returns @@ -1545,13 +1543,10 @@ def sym_diff(self, other, result_name=None): >>> idx1 ^ idx2 Int64Index([1, 5], dtype='int64') """ - if not hasattr(other, '__iter__'): - raise TypeError('Input must be iterable!') - - if not isinstance(other, Index): - other = Index(other) - result_name = result_name or self.name - + self._assert_can_do_setop(other) + other, result_name_update = self._convert_can_do_setop(other) + if result_name is None: + result_name = result_name_update the_diff = sorted(set((self.difference(other)).union(other.difference(self)))) return Index(the_diff, name=result_name) @@ -5460,12 +5455,11 @@ def union(self, other): >>> index.union(index2) """ self._assert_can_do_setop(other) + other, result_names = self._convert_can_do_setop(other) if len(other) == 0 or self.equals(other): return self - result_names = self.names if self.names == other.names else None - uniq_tuples = lib.fast_unique_multiple([self.values, other.values]) return MultiIndex.from_arrays(lzip(*uniq_tuples), sortorder=0, names=result_names) @@ -5483,12 +5477,11 @@ def intersection(self, other): Index """ self._assert_can_do_setop(other) + other, result_names = self._convert_can_do_setop(other) if self.equals(other): return self - result_names = self.names if self.names == other.names else None - self_tuples = self.values other_tuples = other.values uniq_tuples = sorted(set(self_tuples) & set(other_tuples)) @@ -5509,18 +5502,10 @@ def difference(self, other): diff : MultiIndex """ self._assert_can_do_setop(other) + other, result_names = self._convert_can_do_setop(other) - if not isinstance(other, MultiIndex): - if len(other) == 0: + if len(other) == 0: return self - try: - other = MultiIndex.from_tuples(other) - except: - raise TypeError('other must be a MultiIndex or a list of' - ' tuples') - result_names = self.names - else: - result_names = self.names if self.names == other.names else None if self.equals(other): return MultiIndex(levels=[[]] * self.nlevels, @@ -5537,15 +5522,30 @@ def difference(self, other): return MultiIndex.from_tuples(difference, sortorder=0, names=result_names) - def _assert_can_do_setop(self, other): - pass - def astype(self, dtype): if not is_object_dtype(np.dtype(dtype)): raise TypeError('Setting %s dtype to anything other than object ' 'is not supported' % self.__class__) return self._shallow_copy() + def _convert_can_do_setop(self, other): + result_names = self.names + + if not hasattr(other, 'names'): + if len(other) == 0: + other = MultiIndex(levels=[[]] * self.nlevels, + labels=[[]] * self.nlevels, + verify_integrity=False) + else: + msg = 'other must be a MultiIndex or a list of tuples' + try: + other = MultiIndex.from_tuples(other) + except: + raise TypeError(msg) + else: + result_names = self.names if self.names == other.names else None + return other, result_names + def insert(self, loc, item): """ Make new MultiIndex inserting new item at location diff --git a/pandas/tests/test_index.py b/pandas/tests/test_index.py index 93299292cf353..ed84c9764dd84 100644 --- a/pandas/tests/test_index.py +++ b/pandas/tests/test_index.py @@ -251,6 +251,136 @@ def test_take(self): expected = ind[indexer] self.assertTrue(result.equals(expected)) + def test_setops_errorcases(self): + for name, idx in compat.iteritems(self.indices): + # # non-iterable input + cases = [0.5, 'xxx'] + methods = [idx.intersection, idx.union, idx.difference, idx.sym_diff] + + for method in methods: + for case in cases: + assertRaisesRegexp(TypeError, + "Input must be Index or array-like", + method, case) + + def test_intersection_base(self): + for name, idx in compat.iteritems(self.indices): + first = idx[:5] + second = idx[:3] + intersect = first.intersection(second) + + if isinstance(idx, CategoricalIndex): + pass + else: + self.assertTrue(tm.equalContents(intersect, second)) + + # GH 10149 + cases = [klass(second.values) for klass in [np.array, Series, list]] + for case in cases: + if isinstance(idx, PeriodIndex): + msg = "can only call with other PeriodIndex-ed objects" + with tm.assertRaisesRegexp(ValueError, msg): + result = first.intersection(case) + elif isinstance(idx, CategoricalIndex): + pass + else: + result = first.intersection(case) + self.assertTrue(tm.equalContents(result, second)) + + if isinstance(idx, MultiIndex): + msg = "other must be a MultiIndex or a list of tuples" + with tm.assertRaisesRegexp(TypeError, msg): + result = first.intersection([1, 2, 3]) + + def test_union_base(self): + for name, idx in compat.iteritems(self.indices): + first = idx[3:] + second = idx[:5] + everything = idx + union = first.union(second) + self.assertTrue(tm.equalContents(union, everything)) + + # GH 10149 + cases = [klass(second.values) for klass in [np.array, Series, list]] + for case in cases: + if isinstance(idx, PeriodIndex): + msg = "can only call with other PeriodIndex-ed objects" + with tm.assertRaisesRegexp(ValueError, msg): + result = first.union(case) + elif isinstance(idx, CategoricalIndex): + pass + else: + result = first.union(case) + self.assertTrue(tm.equalContents(result, everything)) + + if isinstance(idx, MultiIndex): + msg = "other must be a MultiIndex or a list of tuples" + with tm.assertRaisesRegexp(TypeError, msg): + result = first.union([1, 2, 3]) + + def test_difference_base(self): + for name, idx in compat.iteritems(self.indices): + first = idx[2:] + second = idx[:4] + answer = idx[4:] + result = first.difference(second) + + if isinstance(idx, CategoricalIndex): + pass + else: + self.assertTrue(tm.equalContents(result, answer)) + + # GH 10149 + cases = [klass(second.values) for klass in [np.array, Series, list]] + for case in cases: + if isinstance(idx, PeriodIndex): + msg = "can only call with other PeriodIndex-ed objects" + with tm.assertRaisesRegexp(ValueError, msg): + result = first.difference(case) + elif isinstance(idx, CategoricalIndex): + pass + elif isinstance(idx, (DatetimeIndex, TimedeltaIndex)): + self.assertEqual(result.__class__, answer.__class__) + self.assert_numpy_array_equal(result.asi8, answer.asi8) + else: + result = first.difference(case) + self.assertTrue(tm.equalContents(result, answer)) + + if isinstance(idx, MultiIndex): + msg = "other must be a MultiIndex or a list of tuples" + with tm.assertRaisesRegexp(TypeError, msg): + result = first.difference([1, 2, 3]) + + def test_symmetric_diff(self): + for name, idx in compat.iteritems(self.indices): + first = idx[1:] + second = idx[:-1] + if isinstance(idx, CategoricalIndex): + pass + else: + answer = idx[[0, -1]] + result = first.sym_diff(second) + self.assertTrue(tm.equalContents(result, answer)) + + # GH 10149 + cases = [klass(second.values) for klass in [np.array, Series, list]] + for case in cases: + if isinstance(idx, PeriodIndex): + msg = "can only call with other PeriodIndex-ed objects" + with tm.assertRaisesRegexp(ValueError, msg): + result = first.sym_diff(case) + elif isinstance(idx, CategoricalIndex): + pass + else: + result = first.sym_diff(case) + self.assertTrue(tm.equalContents(result, answer)) + + if isinstance(idx, MultiIndex): + msg = "other must be a MultiIndex or a list of tuples" + with tm.assertRaisesRegexp(TypeError, msg): + result = first.sym_diff([1, 2, 3]) + + class TestIndex(Base, tm.TestCase): _holder = Index _multiprocess_can_split_ = True @@ -620,16 +750,12 @@ def test_intersection(self): first = self.strIndex[:20] second = self.strIndex[:10] intersect = first.intersection(second) - self.assertTrue(tm.equalContents(intersect, second)) # Corner cases inter = first.intersection(first) self.assertIs(inter, first) - # non-iterable input - assertRaisesRegexp(TypeError, "iterable", first.intersection, 0.5) - idx1 = Index([1, 2, 3, 4, 5], name='idx') # if target has the same name, it is preserved idx2 = Index([3, 4, 5, 6, 7], name='idx') @@ -671,6 +797,12 @@ def test_union(self): union = first.union(second) self.assertTrue(tm.equalContents(union, everything)) + # GH 10149 + cases = [klass(second.values) for klass in [np.array, Series, list]] + for case in cases: + result = first.union(case) + self.assertTrue(tm.equalContents(result, everything)) + # Corner cases union = first.union(first) self.assertIs(union, first) @@ -681,9 +813,6 @@ def test_union(self): union = Index([]).union(first) self.assertIs(union, first) - # non-iterable input - assertRaisesRegexp(TypeError, "iterable", first.union, 0.5) - # preserve names first.name = 'A' second.name = 'A' @@ -792,11 +921,7 @@ def test_difference(self): self.assertEqual(len(result), 0) self.assertEqual(result.name, first.name) - # non-iterable input - assertRaisesRegexp(TypeError, "iterable", first.difference, 0.5) - def test_symmetric_diff(self): - # smoke idx1 = Index([1, 2, 3, 4], name='idx1') idx2 = Index([2, 3, 4, 5]) @@ -842,10 +967,6 @@ def test_symmetric_diff(self): self.assertTrue(tm.equalContents(result, expected)) self.assertEqual(result.name, 'new_name') - # other isn't iterable - with tm.assertRaises(TypeError): - Index(idx1,dtype='object').difference(1) - def test_is_numeric(self): self.assertFalse(self.dateIndex.is_numeric()) self.assertFalse(self.strIndex.is_numeric()) @@ -1786,6 +1907,7 @@ def test_equals(self): self.assertFalse(CategoricalIndex(list('aabca') + [np.nan],categories=['c','a','b',np.nan]).equals(list('aabca'))) self.assertTrue(CategoricalIndex(list('aabca') + [np.nan],categories=['c','a','b',np.nan]).equals(list('aabca') + [np.nan])) + class Numeric(Base): def test_numeric_compat(self): @@ -2661,6 +2783,36 @@ def test_time_overflow_for_32bit_machines(self): idx2 = pd.date_range(end='2000', periods=periods, freq='S') self.assertEqual(len(idx2), periods) + def test_intersection(self): + first = self.index + second = self.index[5:] + intersect = first.intersection(second) + self.assertTrue(tm.equalContents(intersect, second)) + + # GH 10149 + cases = [klass(second.values) for klass in [np.array, Series, list]] + for case in cases: + result = first.intersection(case) + self.assertTrue(tm.equalContents(result, second)) + + third = Index(['a', 'b', 'c']) + result = first.intersection(third) + expected = pd.Index([], dtype=object) + self.assert_index_equal(result, expected) + + def test_union(self): + first = self.index[:5] + second = self.index[5:] + everything = self.index + union = first.union(second) + self.assertTrue(tm.equalContents(union, everything)) + + # GH 10149 + cases = [klass(second.values) for klass in [np.array, Series, list]] + for case in cases: + result = first.union(case) + self.assertTrue(tm.equalContents(result, everything)) + class TestPeriodIndex(DatetimeLike, tm.TestCase): _holder = PeriodIndex @@ -2671,7 +2823,7 @@ def setUp(self): self.setup_indices() def create_index(self): - return period_range('20130101',periods=5,freq='D') + return period_range('20130101', periods=5, freq='D') def test_pickle_compat_construction(self): pass diff --git a/pandas/tseries/index.py b/pandas/tseries/index.py index bd0869b9525b7..745c536914e47 100644 --- a/pandas/tseries/index.py +++ b/pandas/tseries/index.py @@ -804,6 +804,7 @@ def union(self, other): ------- y : Index or DatetimeIndex """ + self._assert_can_do_setop(other) if not isinstance(other, DatetimeIndex): try: other = DatetimeIndex(other) @@ -1039,6 +1040,7 @@ def intersection(self, other): ------- y : Index or DatetimeIndex """ + self._assert_can_do_setop(other) if not isinstance(other, DatetimeIndex): try: other = DatetimeIndex(other) diff --git a/pandas/tseries/period.py b/pandas/tseries/period.py index 510887a185054..6627047f0c335 100644 --- a/pandas/tseries/period.py +++ b/pandas/tseries/period.py @@ -679,6 +679,8 @@ def join(self, other, how='left', level=None, return_indexers=False): return self._apply_meta(result) def _assert_can_do_setop(self, other): + super(PeriodIndex, self)._assert_can_do_setop(other) + if not isinstance(other, PeriodIndex): raise ValueError('can only call with other PeriodIndex-ed objects') diff --git a/pandas/tseries/tdi.py b/pandas/tseries/tdi.py index 1443c22909689..de68dd763d68c 100644 --- a/pandas/tseries/tdi.py +++ b/pandas/tseries/tdi.py @@ -436,12 +436,12 @@ def union(self, other): ------- y : Index or TimedeltaIndex """ - if _is_convertible_to_index(other): + self._assert_can_do_setop(other) + if not isinstance(other, TimedeltaIndex): try: other = TimedeltaIndex(other) - except TypeError: + except (TypeError, ValueError): pass - this, other = self, other if this._can_fast_union(other): @@ -581,6 +581,7 @@ def intersection(self, other): ------- y : Index or TimedeltaIndex """ + self._assert_can_do_setop(other) if not isinstance(other, TimedeltaIndex): try: other = TimedeltaIndex(other)