From 1cb3709335ee24174395cb8b5719a5daf2fc8c58 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 8 Dec 2021 09:29:37 -0800 Subject: [PATCH 1/2] BUG: IntervalArray.__cmp__(pd.NA) GH#31882 --- pandas/core/arrays/interval.py | 10 ++++++- pandas/tests/arithmetic/test_interval.py | 34 ++++++++++++------------ 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 19c71f59315aa..ea6673fdaf0cf 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -685,6 +685,13 @@ def _cmp_method(self, other, op): other = pd_array(other) elif not isinstance(other, Interval): # non-interval scalar -> no matches + if other is NA: + # GH#31882 + from pandas.core.arrays import BooleanArray + + arr = np.empty(self.shape, dtype=bool) + mask = np.ones(self.shape, dtype=bool) + return BooleanArray(arr, mask) return invalid_comparison(self, other, op) # determine the dtype of the elements we want to compare @@ -743,7 +750,8 @@ def _cmp_method(self, other, op): if obj is NA: # comparison with np.nan returns NA # github.com/pandas-dev/pandas/pull/37124#discussion_r509095092 - result[i] = op is operator.ne + result = result.astype(object) + result[i] = NA else: raise return result diff --git a/pandas/tests/arithmetic/test_interval.py b/pandas/tests/arithmetic/test_interval.py index 88b26dcc4d707..bea3d28e241fe 100644 --- a/pandas/tests/arithmetic/test_interval.py +++ b/pandas/tests/arithmetic/test_interval.py @@ -20,7 +20,10 @@ timedelta_range, ) import pandas._testing as tm -from pandas.core.arrays import IntervalArray +from pandas.core.arrays import ( + BooleanArray, + IntervalArray, +) @pytest.fixture( @@ -129,18 +132,20 @@ def test_compare_scalar_interval_mixed_closed(self, op, closed, other_closed): expected = self.elementwise_comparison(op, interval_array, other) tm.assert_numpy_array_equal(result, expected) - def test_compare_scalar_na(self, op, interval_array, nulls_fixture, request): + def test_compare_scalar_na(self, op, interval_array, nulls_fixture): result = op(interval_array, nulls_fixture) - expected = self.elementwise_comparison(op, interval_array, nulls_fixture) - if nulls_fixture is pd.NA and interval_array.dtype.subtype != "int64": - mark = pytest.mark.xfail( - raises=AssertionError, - reason="broken for non-integer IntervalArray; see GH 31882", - ) - request.node.add_marker(mark) + if nulls_fixture is pd.NA: + # GH#31882 + exp = np.ones(interval_array.shape, dtype=bool) + expected = BooleanArray(exp, exp) + else: + expected = self.elementwise_comparison(op, interval_array, nulls_fixture) - tm.assert_numpy_array_equal(result, expected) + tm.assert_equal(result, expected) + + rev = op(nulls_fixture, interval_array) + tm.assert_equal(rev, expected) @pytest.mark.parametrize( "other", @@ -214,17 +219,12 @@ def test_compare_list_like_object(self, op, interval_array, other): expected = self.elementwise_comparison(op, interval_array, other) tm.assert_numpy_array_equal(result, expected) - def test_compare_list_like_nan(self, op, interval_array, nulls_fixture, request): + def test_compare_list_like_nan(self, op, interval_array, nulls_fixture): other = [nulls_fixture] * 4 result = op(interval_array, other) expected = self.elementwise_comparison(op, interval_array, other) - if nulls_fixture is pd.NA and interval_array.dtype.subtype != "i8": - reason = "broken for non-integer IntervalArray; see GH 31882" - mark = pytest.mark.xfail(raises=AssertionError, reason=reason) - request.node.add_marker(mark) - - tm.assert_numpy_array_equal(result, expected) + tm.assert_equal(result, expected) @pytest.mark.parametrize( "other", From 5c477878dfbc36031e9ceb1a35fb56961a136591 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 9 Dec 2021 13:15:08 -0800 Subject: [PATCH 2/2] TST: parametrize over box_with_array --- pandas/_testing/__init__.py | 5 +++++ pandas/tests/arithmetic/test_interval.py | 24 +++++++++++++++++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/pandas/_testing/__init__.py b/pandas/_testing/__init__.py index 6248154422252..16094bd88d66f 100644 --- a/pandas/_testing/__init__.py +++ b/pandas/_testing/__init__.py @@ -38,6 +38,7 @@ is_unsigned_integer_dtype, pandas_dtype, ) +from pandas.core.dtypes.dtypes import IntervalDtype import pandas as pd from pandas import ( @@ -282,6 +283,10 @@ def to_array(obj): return DatetimeArray._from_sequence(obj) elif is_timedelta64_dtype(dtype): return TimedeltaArray._from_sequence(obj) + elif isinstance(obj, pd.core.arrays.BooleanArray): + return obj + elif isinstance(dtype, IntervalDtype): + return pd.core.arrays.IntervalArray(obj) else: return np.array(obj) diff --git a/pandas/tests/arithmetic/test_interval.py b/pandas/tests/arithmetic/test_interval.py index bea3d28e241fe..88e3dca62d9e0 100644 --- a/pandas/tests/arithmetic/test_interval.py +++ b/pandas/tests/arithmetic/test_interval.py @@ -24,6 +24,7 @@ BooleanArray, IntervalArray, ) +from pandas.tests.arithmetic.common import get_upcast_box @pytest.fixture( @@ -132,8 +133,20 @@ def test_compare_scalar_interval_mixed_closed(self, op, closed, other_closed): expected = self.elementwise_comparison(op, interval_array, other) tm.assert_numpy_array_equal(result, expected) - def test_compare_scalar_na(self, op, interval_array, nulls_fixture): - result = op(interval_array, nulls_fixture) + def test_compare_scalar_na( + self, op, interval_array, nulls_fixture, box_with_array, request + ): + box = box_with_array + + if box is pd.DataFrame: + if interval_array.dtype.subtype.kind not in "iuf": + mark = pytest.mark.xfail( + reason="raises on DataFrame.transpose (would be fixed by EA2D)" + ) + request.node.add_marker(mark) + + obj = tm.box_expected(interval_array, box) + result = op(obj, nulls_fixture) if nulls_fixture is pd.NA: # GH#31882 @@ -142,9 +155,14 @@ def test_compare_scalar_na(self, op, interval_array, nulls_fixture): else: expected = self.elementwise_comparison(op, interval_array, nulls_fixture) + if not (box is Index and nulls_fixture is pd.NA): + # don't cast expected from BooleanArray to ndarray[object] + xbox = get_upcast_box(obj, nulls_fixture, True) + expected = tm.box_expected(expected, xbox) + tm.assert_equal(result, expected) - rev = op(nulls_fixture, interval_array) + rev = op(nulls_fixture, obj) tm.assert_equal(rev, expected) @pytest.mark.parametrize(