From c3103e5748a4bfd57e32270342d84426656bf22e Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 15 Aug 2023 16:09:02 -0700 Subject: [PATCH 1/2] TST: use single-class pattern in test_masked.py --- pandas/tests/extension/test_masked.py | 81 ++------------------------- 1 file changed, 5 insertions(+), 76 deletions(-) diff --git a/pandas/tests/extension/test_masked.py b/pandas/tests/extension/test_masked.py index c4195be8ea121..bed406e902483 100644 --- a/pandas/tests/extension/test_masked.py +++ b/pandas/tests/extension/test_masked.py @@ -159,11 +159,7 @@ def data_for_grouping(dtype): return pd.array([b, b, na, na, a, a, b, c], dtype=dtype) -class TestDtype(base.BaseDtypeTests): - pass - - -class TestArithmeticOps(base.BaseArithmeticOpsTests): +class TestMaskedArrays(base.ExtensionTests): def _get_expected_exception(self, op_name, obj, other): try: dtype = tm.get_dtype(obj) @@ -179,12 +175,15 @@ def _get_expected_exception(self, op_name, obj, other): # exception message would include "numpy boolean subtract"" return TypeError return None - return super()._get_expected_exception(op_name, obj, other) + return None def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): sdtype = tm.get_dtype(obj) expected = pointwise_result + if op_name in ("eq", "ne", "le", "ge", "lt", "gt"): + return expected.astype("boolean") + if sdtype.kind in "iu": if op_name in ("__rtruediv__", "__truediv__", "__div__"): expected = expected.fillna(np.nan).astype("Float64") @@ -219,11 +218,6 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): expected = expected.astype(sdtype) return expected - series_scalar_exc = None - series_array_exc = None - frame_scalar_exc = None - divmod_exc = None - def test_divmod_series_array(self, data, data_for_twos, request): if data.dtype.kind == "b": mark = pytest.mark.xfail( @@ -234,49 +228,6 @@ def test_divmod_series_array(self, data, data_for_twos, request): request.node.add_marker(mark) super().test_divmod_series_array(data, data_for_twos) - -class TestComparisonOps(base.BaseComparisonOpsTests): - series_scalar_exc = None - series_array_exc = None - frame_scalar_exc = None - - def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): - return pointwise_result.astype("boolean") - - -class TestInterface(base.BaseInterfaceTests): - pass - - -class TestConstructors(base.BaseConstructorsTests): - pass - - -class TestReshaping(base.BaseReshapingTests): - pass - - # for test_concat_mixed_dtypes test - # concat of an Integer and Int coerces to object dtype - # TODO(jreback) once integrated this would - - -class TestGetitem(base.BaseGetitemTests): - pass - - -class TestSetitem(base.BaseSetitemTests): - pass - - -class TestIndex(base.BaseIndexTests): - pass - - -class TestMissing(base.BaseMissingTests): - pass - - -class TestMethods(base.BaseMethodsTests): def test_combine_le(self, data_repeated): # TODO: patching self is a bad pattern here orig_data1, orig_data2 = data_repeated(2) @@ -287,16 +238,6 @@ def test_combine_le(self, data_repeated): self._combine_le_expected_dtype = object super().test_combine_le(data_repeated) - -class TestCasting(base.BaseCastingTests): - pass - - -class TestGroupby(base.BaseGroupbyTests): - pass - - -class TestReduce(base.BaseReduceTests): def _supports_reduction(self, obj, op_name: str) -> bool: if op_name in ["any", "all"] and tm.get_dtype(obj).kind != "b": pytest.skip(reason="Tested in tests/reductions/test_reductions.py") @@ -351,8 +292,6 @@ def _get_expected_reduction_dtype(self, arr, op_name: str): raise TypeError("not supposed to reach this") return cmp_dtype - -class TestAccumulation(base.BaseAccumulateTests): def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool: return True @@ -411,8 +350,6 @@ def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool): else: raise NotImplementedError(f"{op_name} not supported") - -class TestUnaryOps(base.BaseUnaryOpsTests): def test_invert(self, data, request): if data.dtype.kind == "f": mark = pytest.mark.xfail( @@ -423,13 +360,5 @@ def test_invert(self, data, request): super().test_invert(data) -class TestPrinting(base.BasePrintingTests): - pass - - -class TestParsing(base.BaseParsingTests): - pass - - class Test2DCompat(base.Dim2CompatTests): pass From dda944851035ab71e00bf408435d935ccef4ff40 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 15 Aug 2023 16:34:13 -0700 Subject: [PATCH 2/2] TST: use one-class pattern in arrow extension tests --- pandas/tests/extension/test_arrow.py | 65 ++-------------------------- 1 file changed, 4 insertions(+), 61 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index dd1ff925adf5f..4c05049ddfcf5 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -265,7 +265,7 @@ def data_for_twos(data): # TODO: skip otherwise? -class TestBaseCasting(base.BaseCastingTests): +class TestArrowArray(base.ExtensionTests): def test_astype_str(self, data, request): pa_dtype = data.dtype.pyarrow_dtype if pa.types.is_binary(pa_dtype): @@ -276,8 +276,6 @@ def test_astype_str(self, data, request): ) super().test_astype_str(data) - -class TestConstructors(base.BaseConstructorsTests): def test_from_dtype(self, data, request): pa_dtype = data.dtype.pyarrow_dtype if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype): @@ -338,12 +336,6 @@ def test_from_sequence_of_strings_pa_array(self, data, request): result = type(data)._from_sequence_of_strings(pa_array, dtype=data.dtype) tm.assert_extension_array_equal(result, data) - -class TestGetitemTests(base.BaseGetitemTests): - pass - - -class TestBaseAccumulateTests(base.BaseAccumulateTests): def check_accumulate(self, ser, op_name, skipna): result = getattr(ser, op_name)(skipna=skipna) @@ -409,8 +401,6 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques self.check_accumulate(ser, op_name, skipna) - -class TestReduce(base.BaseReduceTests): def _supports_reduction(self, obj, op_name: str) -> bool: dtype = tm.get_dtype(obj) # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has @@ -561,8 +551,6 @@ def test_median_not_approximate(self, typ): result = pd.Series([1, 2], dtype=f"{typ}[pyarrow]").median() assert result == 1.5 - -class TestBaseGroupby(base.BaseGroupbyTests): def test_in_numeric_groupby(self, data_for_grouping): dtype = data_for_grouping.dtype if is_string_dtype(dtype): @@ -583,8 +571,6 @@ def test_in_numeric_groupby(self, data_for_grouping): else: super().test_in_numeric_groupby(data_for_grouping) - -class TestBaseDtype(base.BaseDtypeTests): def test_construct_from_string_own_name(self, dtype, request): pa_dtype = dtype.pyarrow_dtype if pa.types.is_decimal(pa_dtype): @@ -651,20 +637,12 @@ def test_is_not_string_type(self, dtype): else: super().test_is_not_string_type(dtype) - -class TestBaseIndex(base.BaseIndexTests): - pass - - -class TestBaseInterface(base.BaseInterfaceTests): @pytest.mark.xfail( reason="GH 45419: pyarrow.ChunkedArray does not support views.", run=False ) def test_view(self, data): super().test_view(data) - -class TestBaseMissing(base.BaseMissingTests): def test_fillna_no_op_returns_copy(self, data): data = data[~data.isna()] @@ -677,28 +655,18 @@ def test_fillna_no_op_returns_copy(self, data): assert result is not data tm.assert_extension_array_equal(result, data) - -class TestBasePrinting(base.BasePrintingTests): - pass - - -class TestBaseReshaping(base.BaseReshapingTests): @pytest.mark.xfail( reason="GH 45419: pyarrow.ChunkedArray does not support views", run=False ) def test_transpose(self, data): super().test_transpose(data) - -class TestBaseSetitem(base.BaseSetitemTests): @pytest.mark.xfail( reason="GH 45419: pyarrow.ChunkedArray does not support views", run=False ) def test_setitem_preserves_views(self, data): super().test_setitem_preserves_views(data) - -class TestBaseParsing(base.BaseParsingTests): @pytest.mark.parametrize("dtype_backend", ["pyarrow", no_default]) @pytest.mark.parametrize("engine", ["c", "python"]) def test_EA_types(self, engine, data, dtype_backend, request): @@ -736,8 +704,6 @@ def test_EA_types(self, engine, data, dtype_backend, request): expected = df tm.assert_frame_equal(result, expected) - -class TestBaseUnaryOps(base.BaseUnaryOpsTests): def test_invert(self, data, request): pa_dtype = data.dtype.pyarrow_dtype if not pa.types.is_boolean(pa_dtype): @@ -749,8 +715,6 @@ def test_invert(self, data, request): ) super().test_invert(data) - -class TestBaseMethods(base.BaseMethodsTests): @pytest.mark.parametrize("periods", [1, -2]) def test_diff(self, data, periods, request): pa_dtype = data.dtype.pyarrow_dtype @@ -814,8 +778,6 @@ def test_argreduce_series( _combine_le_expected_dtype = "bool[pyarrow]" - -class TestBaseArithmeticOps(base.BaseArithmeticOpsTests): divmod_exc = NotImplementedError def get_op_from_name(self, op_name): @@ -838,6 +800,9 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): # while ArrowExtensionArray maintains original type expected = pointwise_result + if op_name in ["eq", "ne", "lt", "le", "gt", "ge"]: + return pointwise_result.astype("boolean[pyarrow]") + was_frame = False if isinstance(expected, pd.DataFrame): was_frame = True @@ -1121,28 +1086,6 @@ def test_add_series_with_extension_array(self, data, request): ) super().test_add_series_with_extension_array(data) - -class TestBaseComparisonOps(base.BaseComparisonOpsTests): - def test_compare_array(self, data, comparison_op, na_value): - ser = pd.Series(data) - # pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray - # since ser.iloc[0] is a python scalar - other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype)) - if comparison_op.__name__ in ["eq", "ne"]: - # comparison should match point-wise comparisons - result = comparison_op(ser, other) - # Series.combine does not calculate the NA mask correctly - # when comparing over an array - assert result[8] is na_value - assert result[97] is na_value - expected = ser.combine(other, comparison_op) - expected[8] = na_value - expected[97] = na_value - tm.assert_series_equal(result, expected) - - else: - return super().test_compare_array(data, comparison_op) - def test_invalid_other_comp(self, data, comparison_op): # GH 48833 with pytest.raises(