diff --git a/pandas/core/indexes/range.py b/pandas/core/indexes/range.py index f1054635f44db..4def2e4b93553 100644 --- a/pandas/core/indexes/range.py +++ b/pandas/core/indexes/range.py @@ -675,6 +675,9 @@ def _difference(self, other, sort=None): if not isinstance(other, RangeIndex): return super()._difference(other, sort=sort) + if sort is None and self.step < 0: + return self[::-1]._difference(other) + res_name = ops.get_op_result_name(self, other) first = self._range[::-1] if self.step < 0 else self._range @@ -683,36 +686,60 @@ def _difference(self, other, sort=None): overlap = overlap[::-1] if len(overlap) == 0: - result = self.rename(name=res_name) - if sort is None and self.step < 0: - result = result[::-1] - return result + return self.rename(name=res_name) if len(overlap) == len(self): return self[:0].rename(res_name) - if not isinstance(overlap, RangeIndex): - # We won't end up with RangeIndex, so fall back - return super()._difference(other, sort=sort) - if overlap.step != first.step: - # In some cases we might be able to get a RangeIndex back, - # but not worth the effort. - return super()._difference(other, sort=sort) - if overlap[0] == first.start: - # The difference is everything after the intersection - new_rng = range(overlap[-1] + first.step, first.stop, first.step) - elif overlap[-1] == first[-1]: - # The difference is everything before the intersection - new_rng = range(first.start, overlap[0], first.step) + # overlap.step will always be a multiple of self.step (see _intersection) + + if len(overlap) == 1: + if overlap[0] == self[0]: + return self[1:] + + elif overlap[0] == self[-1]: + return self[:-1] + + elif len(self) == 3 and overlap[0] == self[1]: + return self[::2] + + else: + return super()._difference(other, sort=sort) + + if overlap.step == first.step: + if overlap[0] == first.start: + # The difference is everything after the intersection + new_rng = range(overlap[-1] + first.step, first.stop, first.step) + elif overlap[-1] == first[-1]: + # The difference is everything before the intersection + new_rng = range(first.start, overlap[0], first.step) + else: + # The difference is not range-like + # e.g. range(1, 10, 1) and range(3, 7, 1) + return super()._difference(other, sort=sort) + else: - # The difference is not range-like + # We must have len(self) > 1, bc we ruled out above + # len(overlap) == 0 and len(overlap) == len(self) + assert len(self) > 1 + + if overlap.step == first.step * 2: + if overlap[0] == first[0] and overlap[-1] in (first[-1], first[-2]): + # e.g. range(1, 10, 1) and range(1, 10, 2) + return self[1::2] + + elif overlap[0] == first[1] and overlap[-1] in (first[-1], first[-2]): + # e.g. range(1, 10, 1) and range(2, 10, 2) + return self[::2] + + # We can get here with e.g. range(20) and range(0, 10, 2) + + # e.g. range(10) and range(0, 10, 3) return super()._difference(other, sort=sort) new_index = type(self)._simple_new(new_rng, name=res_name) if first is not self._range: new_index = new_index[::-1] - if sort is None and new_index.step < 0: - new_index = new_index[::-1] return new_index def symmetric_difference(self, other, result_name: Hashable = None, sort=None): diff --git a/pandas/tests/indexes/ranges/test_setops.py b/pandas/tests/indexes/ranges/test_setops.py index 53ea11345328c..583391bd96a85 100644 --- a/pandas/tests/indexes/ranges/test_setops.py +++ b/pandas/tests/indexes/ranges/test_setops.py @@ -3,6 +3,11 @@ timedelta, ) +from hypothesis import ( + assume, + given, + strategies as st, +) import numpy as np import pytest @@ -359,11 +364,44 @@ def test_difference_mismatched_step(self): obj = RangeIndex.from_range(range(1, 10), name="foo") result = obj.difference(obj[::2]) - expected = Int64Index(obj[1::2]._values, name=obj.name) + expected = obj[1::2] tm.assert_index_equal(result, expected, exact=True) + result = obj[::-1].difference(obj[::2], sort=False) + tm.assert_index_equal(result, expected[::-1], exact=True) + result = obj.difference(obj[1::2]) - expected = Int64Index(obj[::2]._values, name=obj.name) + expected = obj[::2] + tm.assert_index_equal(result, expected, exact=True) + + result = obj[::-1].difference(obj[1::2], sort=False) + tm.assert_index_equal(result, expected[::-1], exact=True) + + def test_difference_interior_non_preserving(self): + # case with intersection of length 1 but RangeIndex is not preserved + idx = Index(range(10)) + + other = idx[3:4] + result = idx.difference(other) + expected = Int64Index([0, 1, 2, 4, 5, 6, 7, 8, 9]) + tm.assert_index_equal(result, expected, exact=True) + + # case with other.step / self.step > 2 + other = idx[::3] + result = idx.difference(other) + expected = Int64Index([1, 2, 4, 5, 7, 8]) + tm.assert_index_equal(result, expected, exact=True) + + # cases with only reaching one end of left + obj = Index(range(20)) + other = obj[:10:2] + result = obj.difference(other) + expected = Int64Index([1, 3, 5, 7, 9] + list(range(10, 20))) + tm.assert_index_equal(result, expected, exact=True) + + other = obj[1:11:2] + result = obj.difference(other) + expected = Int64Index([0, 2, 4, 6, 8, 10] + list(range(11, 20))) tm.assert_index_equal(result, expected, exact=True) def test_symmetric_difference(self): @@ -391,3 +429,44 @@ def test_symmetric_difference(self): result = left.symmetric_difference(right[1:]) expected = Int64Index([1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14]) tm.assert_index_equal(result, expected) + + +def assert_range_or_not_is_rangelike(index): + """ + Check that we either have a RangeIndex or that this index *cannot* + be represented as a RangeIndex. + """ + if not isinstance(index, RangeIndex) and len(index) > 0: + diff = index[:-1] - index[1:] + assert not (diff == diff[0]).all() + + +@given( + st.integers(-20, 20), + st.integers(-20, 20), + st.integers(-20, 20), + st.integers(-20, 20), + st.integers(-20, 20), + st.integers(-20, 20), +) +def test_range_difference(start1, stop1, step1, start2, stop2, step2): + # test that + # a) we match Int64Index.difference and + # b) we return RangeIndex whenever it is possible to do so. + assume(step1 != 0) + assume(step2 != 0) + + left = RangeIndex(start1, stop1, step1) + right = RangeIndex(start2, stop2, step2) + + result = left.difference(right, sort=None) + assert_range_or_not_is_rangelike(result) + + alt = Int64Index(left).difference(Int64Index(right), sort=None) + tm.assert_index_equal(result, alt, exact="equiv") + + result = left.difference(right, sort=False) + assert_range_or_not_is_rangelike(result) + + alt = Int64Index(left).difference(Int64Index(right), sort=False) + tm.assert_index_equal(result, alt, exact="equiv")