From 899cb05b42d32082a04e87b8be9d82fffc3757c5 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Wed, 29 Jan 2020 11:39:35 -0800 Subject: [PATCH 1/3] stricten, tests --- pandas/core/indexes/datetimelike.py | 7 ++- pandas/core/indexes/period.py | 66 ++++++++++++++++------ pandas/tests/indexes/period/test_period.py | 8 ++- 3 files changed, 60 insertions(+), 21 deletions(-) diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index b87dd0f02252f..9a1b716d0e776 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -466,7 +466,9 @@ def where(self, cond, other=None): other = other.view("i8") result = np.where(cond, values, other).astype("i8") - return self._shallow_copy(result) + arr = type(self._data)._simple_new(result, dtype=self.dtype) + return type(self)._simple_new(arr, name=self.name) + # TODO: were we returning incorrect freq by using shallow_copy? def _summary(self, name=None): """ @@ -570,7 +572,8 @@ def delete(self, loc): if loc.start in (0, None) or loc.stop in (len(self), None): freq = self.freq - return self._shallow_copy(new_i8s, freq=freq) + arr = type(self._data)._simple_new(new_i8s, dtype=self.dtype, freq=freq) + return type(self)._simple_new(arr, name=self.name) class DatetimeTimedeltaMixin(DatetimeIndexOpsMixin, Int64Index): diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index 1e18c16d02784..bf7eaf23321cf 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -265,15 +265,13 @@ def _has_complex_internals(self): return True def _shallow_copy(self, values=None, **kwargs): - # TODO: simplify, figure out type of values if values is None: values = self._data - if isinstance(values, type(self)): - values = values._data - if not isinstance(values, PeriodArray): if isinstance(values, np.ndarray) and values.dtype == "i8": + if "freq" in kwargs and kwargs["freq"] != self.freq: + raise ValueError(kwargs, self.freq) values = PeriodArray(values, freq=self.freq) else: # GH#30713 this should never be reached @@ -281,12 +279,9 @@ def _shallow_copy(self, values=None, **kwargs): # We don't allow changing `freq` in _shallow_copy. validate_dtype_freq(self.dtype, kwargs.get("freq")) - attributes = self._get_attributes_dict() + name = kwargs.get("name", self.name) - attributes.update(kwargs) - if not len(values) and "dtype" not in kwargs: - attributes["dtype"] = self.dtype - return self._simple_new(values, **attributes) + return self._simple_new(values, name=name) def _shallow_copy_with_infer(self, values=None, **kwargs): """ we always want to return a PeriodIndex """ @@ -386,6 +381,40 @@ def __contains__(self, key: Any) -> bool: def _int64index(self): return Int64Index._simple_new(self.asi8, name=self.name) + # ------------------------------------------------------------------------ + # NDarray-Like Methods + + def putmask(self, mask, value): + """ + Return a new Index of the values set with the mask. + + Returns + ------- + Index + + See Also + -------- + numpy.putmask + """ + if isinstance(value, Period): + if value.freq != self.freq: + return self.astype(object).putmask(mask, value) + i8val = value.ordinal + elif value is NaT: + i8val = val.value + elif isinstance(value, (PeriodArray, PeriodIndex)): + if value.freq != self.freq: + return self.astype(object).putmask(mask, value) + i8val = value.asi8 + else: + return self.astype(object).putmask(mask, value) + + i8values = self._data.copy()._data + + np.putmask(i8values, mask, i8val) + parr = PeriodArray(i8values, dtype=self.dtype) + return type(self)._simple_new(parr, name=self.name) + # ------------------------------------------------------------------------ # Index Methods @@ -702,10 +731,10 @@ def insert(self, loc, item): if not isinstance(item, Period) or self.freq != item.freq: return self.astype(object).insert(loc, item) - idx = np.concatenate( - (self[:loc].asi8, np.array([item.ordinal]), self[loc:].asi8) - ) - return self._shallow_copy(idx) + item_arr = type(self._data)._from_sequence([item]) + to_concat = [self[:loc]._data, item_arr, self[loc:]._data] + arr = type(self._data)._concat_same_type(to_concat) + return type(self)._simple_new(arr, name=self.name) def join(self, other, how="left", level=None, return_indexers=False, sort=False): """ @@ -762,7 +791,8 @@ def intersection(self, other, sort=False): i8other = Int64Index._simple_new(other.asi8) i8result = i8self.intersection(i8other, sort=sort) - result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name) + parr = type(self._data)(np.asarray(i8result, dtype=np.int64), dtype=self.dtype) + result = type(self)._simple_new(parr, name=res_name) return result def difference(self, other, sort=None): @@ -773,7 +803,7 @@ def difference(self, other, sort=None): if self.equals(other): # pass an empty PeriodArray with the appropriate dtype - return self._shallow_copy(self._data[:0]) + return type(self)._simple_new(self._data[:0], name=self.name) if is_object_dtype(other): return self.astype(object).difference(other).astype(self.dtype) @@ -785,7 +815,8 @@ def difference(self, other, sort=None): i8other = Int64Index._simple_new(other.asi8) i8result = i8self.difference(i8other, sort=sort) - result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name) + parr = type(self._data)(np.asarray(i8result, dtype=np.int64), dtype=self.dtype) + result = type(self)._simple_new(parr, name=res_name) return result def _union(self, other, sort): @@ -805,7 +836,8 @@ def _union(self, other, sort): i8result = i8self._union(i8other, sort=sort) res_name = get_op_result_name(self, other) - result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name) + parr = type(self._data)(np.asarray(i8result, dtype=np.int64), dtype=self.dtype) + result = type(self)._simple_new(parr, name=res_name) return result # ------------------------------------------------------------------------ diff --git a/pandas/tests/indexes/period/test_period.py b/pandas/tests/indexes/period/test_period.py index 16fa0b0c25925..9f6d864dcac86 100644 --- a/pandas/tests/indexes/period/test_period.py +++ b/pandas/tests/indexes/period/test_period.py @@ -117,7 +117,6 @@ def test_make_time_series(self): assert isinstance(series, Series) def test_shallow_copy_empty(self): - # GH13067 idx = PeriodIndex([], freq="M") result = idx._shallow_copy() @@ -131,11 +130,16 @@ def test_shallow_copy_i8(self): result = pi._shallow_copy(pi.asi8, freq=pi.freq) tm.assert_index_equal(result, pi) + def test_shallow_copy_requires_disallow_period_index(self): + pi = period_range("2018-01-01", periods=3, freq="2D") + with pytest.raises(TypeError, match="PeriodIndex"): + pi._shallow_copy(pi) + def test_shallow_copy_changing_freq_raises(self): pi = period_range("2018-01-01", periods=3, freq="2D") msg = "specified freq and dtype are different" with pytest.raises(IncompatibleFrequency, match=msg): - pi._shallow_copy(pi, freq="H") + pi._shallow_copy(pi._data, freq="H") def test_view_asi8(self): idx = pd.PeriodIndex([], freq="M") From 42de183ec17fbb60e15a685ae10da133e263bcd3 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Wed, 26 Feb 2020 13:45:21 -0800 Subject: [PATCH 2/3] Test for where --- pandas/core/indexes/base.py | 3 ++ pandas/core/indexes/datetimelike.py | 1 - pandas/core/indexes/period.py | 36 +------------------ .../tests/indexes/datetimes/test_indexing.py | 8 +++++ .../tests/indexes/timedeltas/test_indexing.py | 8 +++++ 5 files changed, 20 insertions(+), 36 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index b5e323fbd0fa4..a1681322a4a5f 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -4234,6 +4234,9 @@ def putmask(self, mask, value): values = self.values.copy() try: np.putmask(values, mask, self._convert_for_op(value)) + if is_period_dtype(self.dtype): + # .values cast to object, so we need to cast back + values = type(self)(values)._data return self._shallow_copy(values) except (ValueError, TypeError) as err: if is_object_dtype(self): diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index d9eeb3b1c9a1c..f9ce8eb6d720d 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -522,7 +522,6 @@ def where(self, cond, other=None): result = np.where(cond, values, other).astype("i8") arr = type(self._data)._simple_new(result, dtype=self.dtype) return type(self)._simple_new(arr, name=self.name) - # TODO: were we returning incorrect freq by using shallow_copy? def _summary(self, name=None) -> str: """ diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index b7cc7ae56fca9..de366ddd582bf 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -4,7 +4,7 @@ import numpy as np -from pandas._libs import NaT, index as libindex +from pandas._libs import index as libindex from pandas._libs.lib import no_default from pandas._libs.tslibs import frequencies as libfrequencies, resolution from pandas._libs.tslibs.parsing import parse_time_string @@ -335,40 +335,6 @@ def __contains__(self, key: Any) -> bool: def _int64index(self) -> Int64Index: return Int64Index._simple_new(self.asi8, name=self.name) - # ------------------------------------------------------------------------ - # NDarray-Like Methods - - def putmask(self, mask, value): - """ - Return a new Index of the values set with the mask. - - Returns - ------- - Index - - See Also - -------- - numpy.putmask - """ - if isinstance(value, Period): - if value.freq != self.freq: - return self.astype(object).putmask(mask, value) - i8val = value.ordinal - elif value is NaT: - i8val = value.value - elif isinstance(value, (PeriodArray, PeriodIndex)): - if value.freq != self.freq: - return self.astype(object).putmask(mask, value) - i8val = value.asi8 - else: - return self.astype(object).putmask(mask, value) - - i8values = self._data.copy()._data - - np.putmask(i8values, mask, i8val) - parr = PeriodArray(i8values, dtype=self.dtype) - return type(self)._simple_new(parr, name=self.name) - # ------------------------------------------------------------------------ # Index Methods diff --git a/pandas/tests/indexes/datetimes/test_indexing.py b/pandas/tests/indexes/datetimes/test_indexing.py index ceab670fb5041..554ae76979ba8 100644 --- a/pandas/tests/indexes/datetimes/test_indexing.py +++ b/pandas/tests/indexes/datetimes/test_indexing.py @@ -121,6 +121,14 @@ def test_dti_custom_getitem_matplotlib_hackaround(self): class TestWhere: + def test_where_doesnt_retain_freq(self): + dti = date_range("20130101", periods=3, freq="D", name="idx") + cond = [True, True, False] + expected = DatetimeIndex([dti[0], dti[1], dti[0]], freq=None, name="idx") + + result = dti.where(cond, dti[::-1]) + tm.assert_index_equal(result, expected) + def test_where_other(self): # other is ndarray or Index i = pd.date_range("20130101", periods=3, tz="US/Eastern") diff --git a/pandas/tests/indexes/timedeltas/test_indexing.py b/pandas/tests/indexes/timedeltas/test_indexing.py index 14fff6f9c85b5..5dec799832291 100644 --- a/pandas/tests/indexes/timedeltas/test_indexing.py +++ b/pandas/tests/indexes/timedeltas/test_indexing.py @@ -66,6 +66,14 @@ def test_timestamp_invalid_key(self, key): class TestWhere: + def test_where_doesnt_retain_freq(self): + tdi = timedelta_range("1 day", periods=3, freq="D", name="idx") + cond = [True, True, False] + expected = TimedeltaIndex([tdi[0], tdi[1], tdi[0]], freq=None, name="idx") + + result = tdi.where(cond, tdi[::-1]) + tm.assert_index_equal(result, expected) + def test_where_invalid_dtypes(self): tdi = timedelta_range("1 day", periods=3, freq="D", name="idx") From 9cb67b4c3914767a31fa2f5998047af790f09133 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Wed, 26 Feb 2020 14:15:28 -0800 Subject: [PATCH 3/3] semi-simplify --- pandas/core/indexes/period.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index de366ddd582bf..017f104e18493 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -607,9 +607,10 @@ def insert(self, loc, item): if not isinstance(item, Period) or self.freq != item.freq: return self.astype(object).insert(loc, item) - item_arr = type(self._data)._from_sequence([item]) - to_concat = [self[:loc]._data, item_arr, self[loc:]._data] - arr = type(self._data)._concat_same_type(to_concat) + i8result = np.concatenate( + (self[:loc].asi8, np.array([item.ordinal]), self[loc:].asi8) + ) + arr = type(self._data)._simple_new(i8result, dtype=self.dtype) return type(self)._simple_new(arr, name=self.name) def join(self, other, how="left", level=None, return_indexers=False, sort=False):