diff --git a/pandas/tests/arrays/interval/test_interval.py b/pandas/tests/arrays/interval/test_interval.py index 024721896cc58..be4b2c3e7e74c 100644 --- a/pandas/tests/arrays/interval/test_interval.py +++ b/pandas/tests/arrays/interval/test_interval.py @@ -229,178 +229,3 @@ def test_min_max(self, left_right_dtypes, index_or_series_or_array): res = arr_na.max(skipna=True) assert res == MAX assert type(res) == type(MAX) - - -# ---------------------------------------------------------------------------- -# Arrow interaction - - -def test_arrow_extension_type(): - pa = pytest.importorskip("pyarrow") - - from pandas.core.arrays.arrow.extension_types import ArrowIntervalType - - p1 = ArrowIntervalType(pa.int64(), "left") - p2 = ArrowIntervalType(pa.int64(), "left") - p3 = ArrowIntervalType(pa.int64(), "right") - - assert p1.closed == "left" - assert p1 == p2 - assert p1 != p3 - assert hash(p1) == hash(p2) - assert hash(p1) != hash(p3) - - -def test_arrow_array(): - pa = pytest.importorskip("pyarrow") - - from pandas.core.arrays.arrow.extension_types import ArrowIntervalType - - intervals = pd.interval_range(1, 5, freq=1).array - - result = pa.array(intervals) - assert isinstance(result.type, ArrowIntervalType) - assert result.type.closed == intervals.closed - assert result.type.subtype == pa.int64() - assert result.storage.field("left").equals(pa.array([1, 2, 3, 4], type="int64")) - assert result.storage.field("right").equals(pa.array([2, 3, 4, 5], type="int64")) - - expected = pa.array([{"left": i, "right": i + 1} for i in range(1, 5)]) - assert result.storage.equals(expected) - - # convert to its storage type - result = pa.array(intervals, type=expected.type) - assert result.equals(expected) - - # unsupported conversions - with pytest.raises(TypeError, match="Not supported to convert IntervalArray"): - pa.array(intervals, type="float64") - - with pytest.raises(TypeError, match="Not supported to convert IntervalArray"): - pa.array(intervals, type=ArrowIntervalType(pa.float64(), "left")) - - -def test_arrow_array_missing(): - pa = pytest.importorskip("pyarrow") - - from pandas.core.arrays.arrow.extension_types import ArrowIntervalType - - arr = IntervalArray.from_breaks([0.0, 1.0, 2.0, 3.0]) - arr[1] = None - - result = pa.array(arr) - assert isinstance(result.type, ArrowIntervalType) - assert result.type.closed == arr.closed - assert result.type.subtype == pa.float64() - - # fields have missing values (not NaN) - left = pa.array([0.0, None, 2.0], type="float64") - right = pa.array([1.0, None, 3.0], type="float64") - assert result.storage.field("left").equals(left) - assert result.storage.field("right").equals(right) - - # structarray itself also has missing values on the array level - vals = [ - {"left": 0.0, "right": 1.0}, - {"left": None, "right": None}, - {"left": 2.0, "right": 3.0}, - ] - expected = pa.StructArray.from_pandas(vals, mask=np.array([False, True, False])) - assert result.storage.equals(expected) - - -@pytest.mark.filterwarnings( - "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" -) -@pytest.mark.parametrize( - "breaks", - [[0.0, 1.0, 2.0, 3.0], date_range("2017", periods=4, freq="D")], - ids=["float", "datetime64[ns]"], -) -def test_arrow_table_roundtrip(breaks): - pa = pytest.importorskip("pyarrow") - - from pandas.core.arrays.arrow.extension_types import ArrowIntervalType - - arr = IntervalArray.from_breaks(breaks) - arr[1] = None - df = pd.DataFrame({"a": arr}) - - table = pa.table(df) - assert isinstance(table.field("a").type, ArrowIntervalType) - result = table.to_pandas() - assert isinstance(result["a"].dtype, pd.IntervalDtype) - tm.assert_frame_equal(result, df) - - table2 = pa.concat_tables([table, table]) - result = table2.to_pandas() - expected = pd.concat([df, df], ignore_index=True) - tm.assert_frame_equal(result, expected) - - # GH-41040 - table = pa.table( - [pa.chunked_array([], type=table.column(0).type)], schema=table.schema - ) - result = table.to_pandas() - tm.assert_frame_equal(result, expected[0:0]) - - -@pytest.mark.filterwarnings( - "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" -) -@pytest.mark.parametrize( - "breaks", - [[0.0, 1.0, 2.0, 3.0], date_range("2017", periods=4, freq="D")], - ids=["float", "datetime64[ns]"], -) -def test_arrow_table_roundtrip_without_metadata(breaks): - pa = pytest.importorskip("pyarrow") - - arr = IntervalArray.from_breaks(breaks) - arr[1] = None - df = pd.DataFrame({"a": arr}) - - table = pa.table(df) - # remove the metadata - table = table.replace_schema_metadata() - assert table.schema.metadata is None - - result = table.to_pandas() - assert isinstance(result["a"].dtype, pd.IntervalDtype) - tm.assert_frame_equal(result, df) - - -def test_from_arrow_from_raw_struct_array(): - # in case pyarrow lost the Interval extension type (eg on parquet roundtrip - # with datetime64[ns] subtype, see GH-45881), still allow conversion - # from arrow to IntervalArray - pa = pytest.importorskip("pyarrow") - - arr = pa.array([{"left": 0, "right": 1}, {"left": 1, "right": 2}]) - dtype = pd.IntervalDtype(np.dtype("int64"), closed="neither") - - result = dtype.__from_arrow__(arr) - expected = IntervalArray.from_breaks( - np.array([0, 1, 2], dtype="int64"), closed="neither" - ) - tm.assert_extension_array_equal(result, expected) - - result = dtype.__from_arrow__(pa.chunked_array([arr])) - tm.assert_extension_array_equal(result, expected) - - -@pytest.mark.parametrize("timezone", ["UTC", "US/Pacific", "GMT"]) -def test_interval_index_subtype(timezone, inclusive_endpoints_fixture): - # GH 46999 - dates = date_range("2022", periods=3, tz=timezone) - dtype = f"interval[datetime64[ns, {timezone}], {inclusive_endpoints_fixture}]" - result = IntervalIndex.from_arrays( - ["2022-01-01", "2022-01-02"], - ["2022-01-02", "2022-01-03"], - closed=inclusive_endpoints_fixture, - dtype=dtype, - ) - expected = IntervalIndex.from_arrays( - dates[:-1], dates[1:], closed=inclusive_endpoints_fixture - ) - tm.assert_index_equal(result, expected) diff --git a/pandas/tests/arrays/interval/test_interval_pyarrow.py b/pandas/tests/arrays/interval/test_interval_pyarrow.py new file mode 100644 index 0000000000000..ef8701be81e2b --- /dev/null +++ b/pandas/tests/arrays/interval/test_interval_pyarrow.py @@ -0,0 +1,160 @@ +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.core.arrays import IntervalArray + + +def test_arrow_extension_type(): + pa = pytest.importorskip("pyarrow") + + from pandas.core.arrays.arrow.extension_types import ArrowIntervalType + + p1 = ArrowIntervalType(pa.int64(), "left") + p2 = ArrowIntervalType(pa.int64(), "left") + p3 = ArrowIntervalType(pa.int64(), "right") + + assert p1.closed == "left" + assert p1 == p2 + assert p1 != p3 + assert hash(p1) == hash(p2) + assert hash(p1) != hash(p3) + + +def test_arrow_array(): + pa = pytest.importorskip("pyarrow") + + from pandas.core.arrays.arrow.extension_types import ArrowIntervalType + + intervals = pd.interval_range(1, 5, freq=1).array + + result = pa.array(intervals) + assert isinstance(result.type, ArrowIntervalType) + assert result.type.closed == intervals.closed + assert result.type.subtype == pa.int64() + assert result.storage.field("left").equals(pa.array([1, 2, 3, 4], type="int64")) + assert result.storage.field("right").equals(pa.array([2, 3, 4, 5], type="int64")) + + expected = pa.array([{"left": i, "right": i + 1} for i in range(1, 5)]) + assert result.storage.equals(expected) + + # convert to its storage type + result = pa.array(intervals, type=expected.type) + assert result.equals(expected) + + # unsupported conversions + with pytest.raises(TypeError, match="Not supported to convert IntervalArray"): + pa.array(intervals, type="float64") + + with pytest.raises(TypeError, match="Not supported to convert IntervalArray"): + pa.array(intervals, type=ArrowIntervalType(pa.float64(), "left")) + + +def test_arrow_array_missing(): + pa = pytest.importorskip("pyarrow") + + from pandas.core.arrays.arrow.extension_types import ArrowIntervalType + + arr = IntervalArray.from_breaks([0.0, 1.0, 2.0, 3.0]) + arr[1] = None + + result = pa.array(arr) + assert isinstance(result.type, ArrowIntervalType) + assert result.type.closed == arr.closed + assert result.type.subtype == pa.float64() + + # fields have missing values (not NaN) + left = pa.array([0.0, None, 2.0], type="float64") + right = pa.array([1.0, None, 3.0], type="float64") + assert result.storage.field("left").equals(left) + assert result.storage.field("right").equals(right) + + # structarray itself also has missing values on the array level + vals = [ + {"left": 0.0, "right": 1.0}, + {"left": None, "right": None}, + {"left": 2.0, "right": 3.0}, + ] + expected = pa.StructArray.from_pandas(vals, mask=np.array([False, True, False])) + assert result.storage.equals(expected) + + +@pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" +) +@pytest.mark.parametrize( + "breaks", + [[0.0, 1.0, 2.0, 3.0], pd.date_range("2017", periods=4, freq="D")], + ids=["float", "datetime64[ns]"], +) +def test_arrow_table_roundtrip(breaks): + pa = pytest.importorskip("pyarrow") + + from pandas.core.arrays.arrow.extension_types import ArrowIntervalType + + arr = IntervalArray.from_breaks(breaks) + arr[1] = None + df = pd.DataFrame({"a": arr}) + + table = pa.table(df) + assert isinstance(table.field("a").type, ArrowIntervalType) + result = table.to_pandas() + assert isinstance(result["a"].dtype, pd.IntervalDtype) + tm.assert_frame_equal(result, df) + + table2 = pa.concat_tables([table, table]) + result = table2.to_pandas() + expected = pd.concat([df, df], ignore_index=True) + tm.assert_frame_equal(result, expected) + + # GH#41040 + table = pa.table( + [pa.chunked_array([], type=table.column(0).type)], schema=table.schema + ) + result = table.to_pandas() + tm.assert_frame_equal(result, expected[0:0]) + + +@pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" +) +@pytest.mark.parametrize( + "breaks", + [[0.0, 1.0, 2.0, 3.0], pd.date_range("2017", periods=4, freq="D")], + ids=["float", "datetime64[ns]"], +) +def test_arrow_table_roundtrip_without_metadata(breaks): + pa = pytest.importorskip("pyarrow") + + arr = IntervalArray.from_breaks(breaks) + arr[1] = None + df = pd.DataFrame({"a": arr}) + + table = pa.table(df) + # remove the metadata + table = table.replace_schema_metadata() + assert table.schema.metadata is None + + result = table.to_pandas() + assert isinstance(result["a"].dtype, pd.IntervalDtype) + tm.assert_frame_equal(result, df) + + +def test_from_arrow_from_raw_struct_array(): + # in case pyarrow lost the Interval extension type (eg on parquet roundtrip + # with datetime64[ns] subtype, see GH-45881), still allow conversion + # from arrow to IntervalArray + pa = pytest.importorskip("pyarrow") + + arr = pa.array([{"left": 0, "right": 1}, {"left": 1, "right": 2}]) + dtype = pd.IntervalDtype(np.dtype("int64"), closed="neither") + + result = dtype.__from_arrow__(arr) + expected = IntervalArray.from_breaks( + np.array([0, 1, 2], dtype="int64"), closed="neither" + ) + tm.assert_extension_array_equal(result, expected) + + result = dtype.__from_arrow__(pa.chunked_array([arr])) + tm.assert_extension_array_equal(result, expected) diff --git a/pandas/tests/indexes/interval/test_base.py b/pandas/tests/indexes/interval/test_base.py deleted file mode 100644 index e0155a13481ac..0000000000000 --- a/pandas/tests/indexes/interval/test_base.py +++ /dev/null @@ -1,56 +0,0 @@ -import numpy as np -import pytest - -from pandas import IntervalIndex -import pandas._testing as tm - - -class TestInterval: - """ - Tests specific to the shared common index tests; unrelated tests should be placed - in test_interval.py or the specific test file (e.g. test_astype.py) - """ - - @pytest.fixture - def simple_index(self) -> IntervalIndex: - return IntervalIndex.from_breaks(range(11), closed="right") - - @pytest.fixture - def index(self): - return tm.makeIntervalIndex(10) - - def test_take(self, closed): - index = IntervalIndex.from_breaks(range(11), closed=closed) - - result = index.take(range(10)) - tm.assert_index_equal(result, index) - - result = index.take([0, 0, 1]) - expected = IntervalIndex.from_arrays([0, 0, 1], [1, 1, 2], closed=closed) - tm.assert_index_equal(result, expected) - - def test_where(self, simple_index, listlike_box): - klass = listlike_box - - idx = simple_index - cond = [True] * len(idx) - expected = idx - result = expected.where(klass(cond)) - tm.assert_index_equal(result, expected) - - cond = [False] + [True] * len(idx[1:]) - expected = IntervalIndex([np.nan] + idx[1:].tolist()) - result = idx.where(klass(cond)) - tm.assert_index_equal(result, expected) - - def test_getitem_2d_deprecated(self, simple_index): - # GH#30588 multi-dim indexing is deprecated, but raising is also acceptable - idx = simple_index - with pytest.raises(ValueError, match="multi-dimensional indexing not allowed"): - idx[:, None] - with pytest.raises(ValueError, match="multi-dimensional indexing not allowed"): - # GH#44051 - idx[True] - with pytest.raises(ValueError, match="multi-dimensional indexing not allowed"): - # GH#44051 - idx[False] diff --git a/pandas/tests/indexes/interval/test_constructors.py b/pandas/tests/indexes/interval/test_constructors.py index 1efe5ff980f6c..078a0e06e0ed7 100644 --- a/pandas/tests/indexes/interval/test_constructors.py +++ b/pandas/tests/indexes/interval/test_constructors.py @@ -488,6 +488,23 @@ def test_index_mixed_closed(self): tm.assert_index_equal(result, expected) +@pytest.mark.parametrize("timezone", ["UTC", "US/Pacific", "GMT"]) +def test_interval_index_subtype(timezone, inclusive_endpoints_fixture): + # GH#46999 + dates = date_range("2022", periods=3, tz=timezone) + dtype = f"interval[datetime64[ns, {timezone}], {inclusive_endpoints_fixture}]" + result = IntervalIndex.from_arrays( + ["2022-01-01", "2022-01-02"], + ["2022-01-02", "2022-01-03"], + closed=inclusive_endpoints_fixture, + dtype=dtype, + ) + expected = IntervalIndex.from_arrays( + dates[:-1], dates[1:], closed=inclusive_endpoints_fixture + ) + tm.assert_index_equal(result, expected) + + def test_dtype_closed_mismatch(): # GH#38394 closed specified in both dtype and IntervalIndex constructor diff --git a/pandas/tests/indexes/interval/test_indexing.py b/pandas/tests/indexes/interval/test_indexing.py index db8f697b95cd8..2007a793843c9 100644 --- a/pandas/tests/indexes/interval/test_indexing.py +++ b/pandas/tests/indexes/interval/test_indexing.py @@ -19,12 +19,75 @@ array, date_range, interval_range, + isna, period_range, timedelta_range, ) import pandas._testing as tm +class TestGetItem: + def test_getitem(self, closed): + idx = IntervalIndex.from_arrays((0, 1, np.nan), (1, 2, np.nan), closed=closed) + assert idx[0] == Interval(0.0, 1.0, closed=closed) + assert idx[1] == Interval(1.0, 2.0, closed=closed) + assert isna(idx[2]) + + result = idx[0:1] + expected = IntervalIndex.from_arrays((0.0,), (1.0,), closed=closed) + tm.assert_index_equal(result, expected) + + result = idx[0:2] + expected = IntervalIndex.from_arrays((0.0, 1), (1.0, 2.0), closed=closed) + tm.assert_index_equal(result, expected) + + result = idx[1:3] + expected = IntervalIndex.from_arrays( + (1.0, np.nan), (2.0, np.nan), closed=closed + ) + tm.assert_index_equal(result, expected) + + def test_getitem_2d_deprecated(self): + # GH#30588 multi-dim indexing is deprecated, but raising is also acceptable + idx = IntervalIndex.from_breaks(range(11), closed="right") + with pytest.raises(ValueError, match="multi-dimensional indexing not allowed"): + idx[:, None] + with pytest.raises(ValueError, match="multi-dimensional indexing not allowed"): + # GH#44051 + idx[True] + with pytest.raises(ValueError, match="multi-dimensional indexing not allowed"): + # GH#44051 + idx[False] + + +class TestWhere: + def test_where(self, listlike_box): + klass = listlike_box + + idx = IntervalIndex.from_breaks(range(11), closed="right") + cond = [True] * len(idx) + expected = idx + result = expected.where(klass(cond)) + tm.assert_index_equal(result, expected) + + cond = [False] + [True] * len(idx[1:]) + expected = IntervalIndex([np.nan] + idx[1:].tolist()) + result = idx.where(klass(cond)) + tm.assert_index_equal(result, expected) + + +class TestTake: + def test_take(self, closed): + index = IntervalIndex.from_breaks(range(11), closed=closed) + + result = index.take(range(10)) + tm.assert_index_equal(result, index) + + result = index.take([0, 0, 1]) + expected = IntervalIndex.from_arrays([0, 0, 1], [1, 1, 2], closed=closed) + tm.assert_index_equal(result, expected) + + class TestGetLoc: @pytest.mark.parametrize("side", ["right", "left", "both", "neither"]) def test_get_loc_interval(self, closed, side): diff --git a/pandas/tests/indexes/interval/test_interval.py b/pandas/tests/indexes/interval/test_interval.py index dea40eff8d2ac..e19b1700236f5 100644 --- a/pandas/tests/indexes/interval/test_interval.py +++ b/pandas/tests/indexes/interval/test_interval.py @@ -341,26 +341,6 @@ def test_is_monotonic_with_nans(self): assert not index._is_strictly_monotonic_decreasing assert not index.is_monotonic_decreasing - def test_get_item(self, closed): - i = IntervalIndex.from_arrays((0, 1, np.nan), (1, 2, np.nan), closed=closed) - assert i[0] == Interval(0.0, 1.0, closed=closed) - assert i[1] == Interval(1.0, 2.0, closed=closed) - assert isna(i[2]) - - result = i[0:1] - expected = IntervalIndex.from_arrays((0.0,), (1.0,), closed=closed) - tm.assert_index_equal(result, expected) - - result = i[0:2] - expected = IntervalIndex.from_arrays((0.0, 1), (1.0, 2.0), closed=closed) - tm.assert_index_equal(result, expected) - - result = i[1:3] - expected = IntervalIndex.from_arrays( - (1.0, np.nan), (2.0, np.nan), closed=closed - ) - tm.assert_index_equal(result, expected) - @pytest.mark.parametrize( "breaks", [ diff --git a/pandas/tests/scalar/interval/test_arithmetic.py b/pandas/tests/scalar/interval/test_arithmetic.py index 863446c64de42..603763227cb88 100644 --- a/pandas/tests/scalar/interval/test_arithmetic.py +++ b/pandas/tests/scalar/interval/test_arithmetic.py @@ -8,56 +8,185 @@ Timedelta, Timestamp, ) +import pandas._testing as tm -@pytest.mark.parametrize("method", ["__add__", "__sub__"]) -@pytest.mark.parametrize( - "interval", - [ - Interval(Timestamp("2017-01-01 00:00:00"), Timestamp("2018-01-01 00:00:00")), - Interval(Timedelta(days=7), Timedelta(days=14)), - ], -) -@pytest.mark.parametrize( - "delta", [Timedelta(days=7), timedelta(7), np.timedelta64(7, "D")] -) -def test_time_interval_add_subtract_timedelta(interval, delta, method): - # https://github.com/pandas-dev/pandas/issues/32023 - result = getattr(interval, method)(delta) - left = getattr(interval.left, method)(delta) - right = getattr(interval.right, method)(delta) - expected = Interval(left, right) +class TestIntervalArithmetic: + def test_interval_add(self, closed): + interval = Interval(0, 1, closed=closed) + expected = Interval(1, 2, closed=closed) - assert result == expected + result = interval + 1 + assert result == expected + result = 1 + interval + assert result == expected -@pytest.mark.parametrize("interval", [Interval(1, 2), Interval(1.0, 2.0)]) -@pytest.mark.parametrize( - "delta", [Timedelta(days=7), timedelta(7), np.timedelta64(7, "D")] -) -def test_numeric_interval_add_timedelta_raises(interval, delta): - # https://github.com/pandas-dev/pandas/issues/32023 - msg = "|".join( + result = interval + result += 1 + assert result == expected + + msg = r"unsupported operand type\(s\) for \+" + with pytest.raises(TypeError, match=msg): + interval + interval + + with pytest.raises(TypeError, match=msg): + interval + "foo" + + def test_interval_sub(self, closed): + interval = Interval(0, 1, closed=closed) + expected = Interval(-1, 0, closed=closed) + + result = interval - 1 + assert result == expected + + result = interval + result -= 1 + assert result == expected + + msg = r"unsupported operand type\(s\) for -" + with pytest.raises(TypeError, match=msg): + interval - interval + + with pytest.raises(TypeError, match=msg): + interval - "foo" + + def test_interval_mult(self, closed): + interval = Interval(0, 1, closed=closed) + expected = Interval(0, 2, closed=closed) + + result = interval * 2 + assert result == expected + + result = 2 * interval + assert result == expected + + result = interval + result *= 2 + assert result == expected + + msg = r"unsupported operand type\(s\) for \*" + with pytest.raises(TypeError, match=msg): + interval * interval + + msg = r"can\'t multiply sequence by non-int" + with pytest.raises(TypeError, match=msg): + interval * "foo" + + def test_interval_div(self, closed): + interval = Interval(0, 1, closed=closed) + expected = Interval(0, 0.5, closed=closed) + + result = interval / 2.0 + assert result == expected + + result = interval + result /= 2.0 + assert result == expected + + msg = r"unsupported operand type\(s\) for /" + with pytest.raises(TypeError, match=msg): + interval / interval + + with pytest.raises(TypeError, match=msg): + interval / "foo" + + def test_interval_floordiv(self, closed): + interval = Interval(1, 2, closed=closed) + expected = Interval(0, 1, closed=closed) + + result = interval // 2 + assert result == expected + + result = interval + result //= 2 + assert result == expected + + msg = r"unsupported operand type\(s\) for //" + with pytest.raises(TypeError, match=msg): + interval // interval + + with pytest.raises(TypeError, match=msg): + interval // "foo" + + @pytest.mark.parametrize("method", ["__add__", "__sub__"]) + @pytest.mark.parametrize( + "interval", [ - "unsupported operand", - "cannot use operands", - "Only numeric, Timestamp and Timedelta endpoints are allowed", - ] + Interval( + Timestamp("2017-01-01 00:00:00"), Timestamp("2018-01-01 00:00:00") + ), + Interval(Timedelta(days=7), Timedelta(days=14)), + ], ) - with pytest.raises((TypeError, ValueError), match=msg): - interval + delta + @pytest.mark.parametrize( + "delta", [Timedelta(days=7), timedelta(7), np.timedelta64(7, "D")] + ) + def test_time_interval_add_subtract_timedelta(self, interval, delta, method): + # https://github.com/pandas-dev/pandas/issues/32023 + result = getattr(interval, method)(delta) + left = getattr(interval.left, method)(delta) + right = getattr(interval.right, method)(delta) + expected = Interval(left, right) + + assert result == expected + + @pytest.mark.parametrize("interval", [Interval(1, 2), Interval(1.0, 2.0)]) + @pytest.mark.parametrize( + "delta", [Timedelta(days=7), timedelta(7), np.timedelta64(7, "D")] + ) + def test_numeric_interval_add_timedelta_raises(self, interval, delta): + # https://github.com/pandas-dev/pandas/issues/32023 + msg = "|".join( + [ + "unsupported operand", + "cannot use operands", + "Only numeric, Timestamp and Timedelta endpoints are allowed", + ] + ) + with pytest.raises((TypeError, ValueError), match=msg): + interval + delta + + with pytest.raises((TypeError, ValueError), match=msg): + delta + interval + + @pytest.mark.parametrize("klass", [timedelta, np.timedelta64, Timedelta]) + def test_timedelta_add_timestamp_interval(self, klass): + delta = klass(0) + expected = Interval(Timestamp("2020-01-01"), Timestamp("2020-02-01")) + + result = delta + expected + assert result == expected + + result = expected + delta + assert result == expected - with pytest.raises((TypeError, ValueError), match=msg): - delta + interval +class TestIntervalComparisons: + def test_interval_equal(self): + assert Interval(0, 1) == Interval(0, 1, closed="right") + assert Interval(0, 1) != Interval(0, 1, closed="left") + assert Interval(0, 1) != 0 -@pytest.mark.parametrize("klass", [timedelta, np.timedelta64, Timedelta]) -def test_timedelta_add_timestamp_interval(klass): - delta = klass(0) - expected = Interval(Timestamp("2020-01-01"), Timestamp("2020-02-01")) + def test_interval_comparison(self): + msg = ( + "'<' not supported between instances of " + "'pandas._libs.interval.Interval' and 'int'" + ) + with pytest.raises(TypeError, match=msg): + Interval(0, 1) < 2 - result = delta + expected - assert result == expected + assert Interval(0, 1) < Interval(1, 2) + assert Interval(0, 1) < Interval(0, 2) + assert Interval(0, 1) < Interval(0.5, 1.5) + assert Interval(0, 1) <= Interval(0, 1) + assert Interval(0, 1) > Interval(-1, 2) + assert Interval(0, 1) >= Interval(0, 1) - result = expected + delta - assert result == expected + def test_equality_comparison_broadcasts_over_array(self): + # https://github.com/pandas-dev/pandas/issues/35931 + interval = Interval(0, 1) + arr = np.array([interval, interval]) + result = interval == arr + expected = np.array([True, True]) + tm.assert_numpy_array_equal(result, expected) diff --git a/pandas/tests/scalar/interval/test_constructors.py b/pandas/tests/scalar/interval/test_constructors.py new file mode 100644 index 0000000000000..a4bc00b923434 --- /dev/null +++ b/pandas/tests/scalar/interval/test_constructors.py @@ -0,0 +1,51 @@ +import pytest + +from pandas import ( + Interval, + Period, + Timestamp, +) + + +class TestIntervalConstructors: + @pytest.mark.parametrize( + "left, right", + [ + ("a", "z"), + (("a", "b"), ("c", "d")), + (list("AB"), list("ab")), + (Interval(0, 1), Interval(1, 2)), + (Period("2018Q1", freq="Q"), Period("2018Q1", freq="Q")), + ], + ) + def test_construct_errors(self, left, right): + # GH#23013 + msg = "Only numeric, Timestamp and Timedelta endpoints are allowed" + with pytest.raises(ValueError, match=msg): + Interval(left, right) + + def test_constructor_errors(self): + msg = "invalid option for 'closed': foo" + with pytest.raises(ValueError, match=msg): + Interval(0, 1, closed="foo") + + msg = "left side of interval must be <= right side" + with pytest.raises(ValueError, match=msg): + Interval(1, 0) + + @pytest.mark.parametrize( + "tz_left, tz_right", [(None, "UTC"), ("UTC", None), ("UTC", "US/Eastern")] + ) + def test_constructor_errors_tz(self, tz_left, tz_right): + # GH#18538 + left = Timestamp("2017-01-01", tz=tz_left) + right = Timestamp("2017-01-02", tz=tz_right) + + if tz_left is None or tz_right is None: + error = TypeError + msg = "Cannot compare tz-naive and tz-aware timestamps" + else: + error = ValueError + msg = "left and right must have the same time zone" + with pytest.raises(error, match=msg): + Interval(left, right) diff --git a/pandas/tests/scalar/interval/test_contains.py b/pandas/tests/scalar/interval/test_contains.py new file mode 100644 index 0000000000000..8dfca117a658b --- /dev/null +++ b/pandas/tests/scalar/interval/test_contains.py @@ -0,0 +1,73 @@ +import pytest + +from pandas import ( + Interval, + Timedelta, + Timestamp, +) + + +class TestContains: + def test_contains(self): + interval = Interval(0, 1) + assert 0.5 in interval + assert 1 in interval + assert 0 not in interval + + interval_both = Interval(0, 1, "both") + assert 0 in interval_both + assert 1 in interval_both + + interval_neither = Interval(0, 1, closed="neither") + assert 0 not in interval_neither + assert 0.5 in interval_neither + assert 1 not in interval_neither + + def test_contains_interval(self, inclusive_endpoints_fixture): + interval1 = Interval(0, 1, "both") + interval2 = Interval(0, 1, inclusive_endpoints_fixture) + assert interval1 in interval1 + assert interval2 in interval2 + assert interval2 in interval1 + assert interval1 not in interval2 or inclusive_endpoints_fixture == "both" + + def test_contains_infinite_length(self): + interval1 = Interval(0, 1, "both") + interval2 = Interval(float("-inf"), float("inf"), "neither") + assert interval1 in interval2 + assert interval2 not in interval1 + + def test_contains_zero_length(self): + interval1 = Interval(0, 1, "both") + interval2 = Interval(-1, -1, "both") + interval3 = Interval(0.5, 0.5, "both") + assert interval2 not in interval1 + assert interval3 in interval1 + assert interval2 not in interval3 and interval3 not in interval2 + assert interval1 not in interval2 and interval1 not in interval3 + + @pytest.mark.parametrize( + "type1", + [ + (0, 1), + (Timestamp(2000, 1, 1, 0), Timestamp(2000, 1, 1, 1)), + (Timedelta("0h"), Timedelta("1h")), + ], + ) + @pytest.mark.parametrize( + "type2", + [ + (0, 1), + (Timestamp(2000, 1, 1, 0), Timestamp(2000, 1, 1, 1)), + (Timedelta("0h"), Timedelta("1h")), + ], + ) + def test_contains_mixed_types(self, type1, type2): + interval1 = Interval(*type1) + interval2 = Interval(*type2) + if type1 == type2: + assert interval1 in interval2 + else: + msg = "^'<=' not supported between instances of" + with pytest.raises(TypeError, match=msg): + interval1 in interval2 diff --git a/pandas/tests/scalar/interval/test_formats.py b/pandas/tests/scalar/interval/test_formats.py new file mode 100644 index 0000000000000..6bf7aa91df3ce --- /dev/null +++ b/pandas/tests/scalar/interval/test_formats.py @@ -0,0 +1,11 @@ +from pandas import Interval + + +def test_interval_repr(): + interval = Interval(0, 1) + assert repr(interval) == "Interval(0, 1, closed='right')" + assert str(interval) == "(0, 1]" + + interval_left = Interval(0, 1, closed="left") + assert repr(interval_left) == "Interval(0, 1, closed='left')" + assert str(interval_left) == "[0, 1)" diff --git a/pandas/tests/scalar/interval/test_interval.py b/pandas/tests/scalar/interval/test_interval.py index 4841c488a5768..91b31e82f9c52 100644 --- a/pandas/tests/scalar/interval/test_interval.py +++ b/pandas/tests/scalar/interval/test_interval.py @@ -3,12 +3,9 @@ from pandas import ( Interval, - Period, Timedelta, Timestamp, ) -import pandas._testing as tm -import pandas.core.common as com @pytest.fixture @@ -23,48 +20,6 @@ def test_properties(self, interval): assert interval.right == 1 assert interval.mid == 0.5 - def test_repr(self, interval): - assert repr(interval) == "Interval(0, 1, closed='right')" - assert str(interval) == "(0, 1]" - - interval_left = Interval(0, 1, closed="left") - assert repr(interval_left) == "Interval(0, 1, closed='left')" - assert str(interval_left) == "[0, 1)" - - def test_contains(self, interval): - assert 0.5 in interval - assert 1 in interval - assert 0 not in interval - - interval_both = Interval(0, 1, "both") - assert 0 in interval_both - assert 1 in interval_both - - interval_neither = Interval(0, 1, closed="neither") - assert 0 not in interval_neither - assert 0.5 in interval_neither - assert 1 not in interval_neither - - def test_equal(self): - assert Interval(0, 1) == Interval(0, 1, closed="right") - assert Interval(0, 1) != Interval(0, 1, closed="left") - assert Interval(0, 1) != 0 - - def test_comparison(self): - msg = ( - "'<' not supported between instances of " - "'pandas._libs.interval.Interval' and 'int'" - ) - with pytest.raises(TypeError, match=msg): - Interval(0, 1) < 2 - - assert Interval(0, 1) < Interval(1, 2) - assert Interval(0, 1) < Interval(0, 2) - assert Interval(0, 1) < Interval(0.5, 1.5) - assert Interval(0, 1) <= Interval(0, 1) - assert Interval(0, 1) > Interval(-1, 2) - assert Interval(0, 1) >= Interval(0, 1) - def test_hash(self, interval): # should not raise hash(interval) @@ -130,150 +85,3 @@ def test_is_empty(self, left, right, closed): result = iv.is_empty expected = closed != "both" assert result is expected - - @pytest.mark.parametrize( - "left, right", - [ - ("a", "z"), - (("a", "b"), ("c", "d")), - (list("AB"), list("ab")), - (Interval(0, 1), Interval(1, 2)), - (Period("2018Q1", freq="Q"), Period("2018Q1", freq="Q")), - ], - ) - def test_construct_errors(self, left, right): - # GH 23013 - msg = "Only numeric, Timestamp and Timedelta endpoints are allowed" - with pytest.raises(ValueError, match=msg): - Interval(left, right) - - def test_math_add(self, closed): - interval = Interval(0, 1, closed=closed) - expected = Interval(1, 2, closed=closed) - - result = interval + 1 - assert result == expected - - result = 1 + interval - assert result == expected - - result = interval - result += 1 - assert result == expected - - msg = r"unsupported operand type\(s\) for \+" - with pytest.raises(TypeError, match=msg): - interval + interval - - with pytest.raises(TypeError, match=msg): - interval + "foo" - - def test_math_sub(self, closed): - interval = Interval(0, 1, closed=closed) - expected = Interval(-1, 0, closed=closed) - - result = interval - 1 - assert result == expected - - result = interval - result -= 1 - assert result == expected - - msg = r"unsupported operand type\(s\) for -" - with pytest.raises(TypeError, match=msg): - interval - interval - - with pytest.raises(TypeError, match=msg): - interval - "foo" - - def test_math_mult(self, closed): - interval = Interval(0, 1, closed=closed) - expected = Interval(0, 2, closed=closed) - - result = interval * 2 - assert result == expected - - result = 2 * interval - assert result == expected - - result = interval - result *= 2 - assert result == expected - - msg = r"unsupported operand type\(s\) for \*" - with pytest.raises(TypeError, match=msg): - interval * interval - - msg = r"can\'t multiply sequence by non-int" - with pytest.raises(TypeError, match=msg): - interval * "foo" - - def test_math_div(self, closed): - interval = Interval(0, 1, closed=closed) - expected = Interval(0, 0.5, closed=closed) - - result = interval / 2.0 - assert result == expected - - result = interval - result /= 2.0 - assert result == expected - - msg = r"unsupported operand type\(s\) for /" - with pytest.raises(TypeError, match=msg): - interval / interval - - with pytest.raises(TypeError, match=msg): - interval / "foo" - - def test_math_floordiv(self, closed): - interval = Interval(1, 2, closed=closed) - expected = Interval(0, 1, closed=closed) - - result = interval // 2 - assert result == expected - - result = interval - result //= 2 - assert result == expected - - msg = r"unsupported operand type\(s\) for //" - with pytest.raises(TypeError, match=msg): - interval // interval - - with pytest.raises(TypeError, match=msg): - interval // "foo" - - def test_constructor_errors(self): - msg = "invalid option for 'closed': foo" - with pytest.raises(ValueError, match=msg): - Interval(0, 1, closed="foo") - - msg = "left side of interval must be <= right side" - with pytest.raises(ValueError, match=msg): - Interval(1, 0) - - @pytest.mark.parametrize( - "tz_left, tz_right", [(None, "UTC"), ("UTC", None), ("UTC", "US/Eastern")] - ) - def test_constructor_errors_tz(self, tz_left, tz_right): - # GH 18538 - left = Timestamp("2017-01-01", tz=tz_left) - right = Timestamp("2017-01-02", tz=tz_right) - - if com.any_none(tz_left, tz_right): - error = TypeError - msg = "Cannot compare tz-naive and tz-aware timestamps" - else: - error = ValueError - msg = "left and right must have the same time zone" - with pytest.raises(error, match=msg): - Interval(left, right) - - def test_equality_comparison_broadcasts_over_array(self): - # https://github.com/pandas-dev/pandas/issues/35931 - interval = Interval(0, 1) - arr = np.array([interval, interval]) - result = interval == arr - expected = np.array([True, True]) - tm.assert_numpy_array_equal(result, expected) diff --git a/pandas/tests/scalar/interval/test_ops.py b/pandas/tests/scalar/interval/test_overlaps.py similarity index 54% rename from pandas/tests/scalar/interval/test_ops.py rename to pandas/tests/scalar/interval/test_overlaps.py index 92db6ac772830..7fcf59d7bb4af 100644 --- a/pandas/tests/scalar/interval/test_ops.py +++ b/pandas/tests/scalar/interval/test_overlaps.py @@ -1,4 +1,3 @@ -"""Tests for Interval-Interval operations, such as overlaps, contains, etc.""" import pytest from pandas import ( @@ -66,54 +65,3 @@ def test_overlaps_invalid_type(self, other): msg = f"`other` must be an Interval, got {type(other).__name__}" with pytest.raises(TypeError, match=msg): interval.overlaps(other) - - -class TestContains: - def test_contains_interval(self, inclusive_endpoints_fixture): - interval1 = Interval(0, 1, "both") - interval2 = Interval(0, 1, inclusive_endpoints_fixture) - assert interval1 in interval1 - assert interval2 in interval2 - assert interval2 in interval1 - assert interval1 not in interval2 or inclusive_endpoints_fixture == "both" - - def test_contains_infinite_length(self): - interval1 = Interval(0, 1, "both") - interval2 = Interval(float("-inf"), float("inf"), "neither") - assert interval1 in interval2 - assert interval2 not in interval1 - - def test_contains_zero_length(self): - interval1 = Interval(0, 1, "both") - interval2 = Interval(-1, -1, "both") - interval3 = Interval(0.5, 0.5, "both") - assert interval2 not in interval1 - assert interval3 in interval1 - assert interval2 not in interval3 and interval3 not in interval2 - assert interval1 not in interval2 and interval1 not in interval3 - - @pytest.mark.parametrize( - "type1", - [ - (0, 1), - (Timestamp(2000, 1, 1, 0), Timestamp(2000, 1, 1, 1)), - (Timedelta("0h"), Timedelta("1h")), - ], - ) - @pytest.mark.parametrize( - "type2", - [ - (0, 1), - (Timestamp(2000, 1, 1, 0), Timestamp(2000, 1, 1, 1)), - (Timedelta("0h"), Timedelta("1h")), - ], - ) - def test_contains_mixed_types(self, type1, type2): - interval1 = Interval(*type1) - interval2 = Interval(*type2) - if type1 == type2: - assert interval1 in interval2 - else: - msg = "^'<=' not supported between instances of" - with pytest.raises(TypeError, match=msg): - interval1 in interval2