# Summary of this patch
#
# pandas/core/arrays/sparse/array.py
#   * Imports OpsMixin from pandas.core.arraylike and makes it the first base
#     class of SparseArray.
#   * __getitem__: drops the function-local import of check_bool_indexer and
#     the call to it in the boolean-indexer branch.
#   * Removes the classmethod factories _create_unary_method,
#     _create_comparison_method, _add_unary_ops and _add_comparison_ops, plus
#     the module-level SparseArray._add_comparison_ops() and
#     SparseArray._add_unary_ops() calls.
#   * Adds plain instance methods instead: _cmp_method(self, other, op), the
#     alias _logical_method = _cmp_method, _unary_method(self, op), and
#     explicit __pos__ / __neg__ / __invert__ that delegate to _unary_method.
#   * _cmp_method now derives the op name with op.__name__.strip("_") rather
#     than special-casing "and_" / "or_".
#
# setup.cfg
#   * Removes the [mypy-pandas.core.arrays.sparse.array]
#     check_untyped_defs=False override.
#
# NOTE(review): the removed cmp_method was wrapped in
#   @unpack_zerodim_and_defer(op_name); the new _cmp_method carries no
#   decorator in this patch. Presumably OpsMixin applies it when dispatching
#   the comparison/logical dunders -- confirm against pandas/core/arraylike.py.
# NOTE(review): __and__ / __or__ / __xor__ are no longer assigned explicitly
#   here (and __xor__ previously came from _create_arithmetic_method).
#   Presumably OpsMixin routes them to _logical_method -- confirm.
# NOTE(review): line breaks and leading indentation were restored from a
#   whitespace-collapsed copy. Blank context lines are placed so every hunk
#   matches its "@@ -a,b +c,d @@" line counts; indentation was re-derived from
#   Python syntax. Verify against the target file before applying.
diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py
index d4ec641794fc2..5a66bf522215a 100644
--- a/pandas/core/arrays/sparse/array.py
+++ b/pandas/core/arrays/sparse/array.py
@@ -40,6 +40,7 @@
 from pandas.core.dtypes.missing import isna, na_value_for_dtype, notna
 
 import pandas.core.algorithms as algos
+from pandas.core.arraylike import OpsMixin
 from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin
 from pandas.core.arrays.sparse.dtype import SparseDtype
 from pandas.core.base import PandasObject
@@ -195,7 +196,7 @@ def _wrap_result(name, data, sparse_index, fill_value, dtype=None):
     )
 
 
-class SparseArray(PandasObject, ExtensionArray, ExtensionOpsMixin):
+class SparseArray(OpsMixin, PandasObject, ExtensionArray, ExtensionOpsMixin):
     """
     An ExtensionArray for storing sparse data.
 
@@ -762,8 +763,6 @@ def value_counts(self, dropna=True):
     # --------
 
     def __getitem__(self, key):
-        # avoid mypy issues when importing at the top-level
-        from pandas.core.indexing import check_bool_indexer
 
         if isinstance(key, tuple):
             if len(key) > 1:
@@ -796,7 +795,6 @@ def __getitem__(self, key):
             key = check_array_indexer(self, key)
 
             if com.is_bool_indexer(key):
-                key = check_bool_indexer(self, key)
 
                 return self.take(np.arange(len(key), dtype=np.int32)[key])
             elif hasattr(key, "__len__"):
@@ -1390,17 +1388,6 @@ def __abs__(self):
     # Ops
     # ------------------------------------------------------------------------
 
-    @classmethod
-    def _create_unary_method(cls, op) -> Callable[["SparseArray"], "SparseArray"]:
-        def sparse_unary_method(self) -> "SparseArray":
-            fill_value = op(np.array(self.fill_value)).item()
-            values = op(self.sp_values)
-            dtype = SparseDtype(values.dtype, fill_value)
-            return cls._simple_new(values, self.sp_index, dtype)
-
-        name = f"__{op.__name__}__"
-        return compat.set_function_name(sparse_unary_method, name, cls)
-
     @classmethod
     def _create_arithmetic_method(cls, op):
         op_name = op.__name__
@@ -1444,56 +1431,48 @@ def sparse_arithmetic_method(self, other):
         name = f"__{op.__name__}__"
         return compat.set_function_name(sparse_arithmetic_method, name, cls)
 
-    @classmethod
-    def _create_comparison_method(cls, op):
-        op_name = op.__name__
-        if op_name in {"and_", "or_"}:
-            op_name = op_name[:-1]
+    def _cmp_method(self, other, op) -> "SparseArray":
+        if not is_scalar(other) and not isinstance(other, type(self)):
+            # convert list-like to ndarray
+            other = np.asarray(other)
 
-        @unpack_zerodim_and_defer(op_name)
-        def cmp_method(self, other):
-
-            if not is_scalar(other) and not isinstance(other, type(self)):
-                # convert list-like to ndarray
-                other = np.asarray(other)
+        if isinstance(other, np.ndarray):
+            # TODO: make this more flexible than just ndarray...
+            if len(self) != len(other):
+                raise AssertionError(f"length mismatch: {len(self)} vs. {len(other)}")
+            other = SparseArray(other, fill_value=self.fill_value)
 
-            if isinstance(other, np.ndarray):
-                # TODO: make this more flexible than just ndarray...
-                if len(self) != len(other):
-                    raise AssertionError(
-                        f"length mismatch: {len(self)} vs. {len(other)}"
-                    )
-                other = SparseArray(other, fill_value=self.fill_value)
+        if isinstance(other, SparseArray):
+            op_name = op.__name__.strip("_")
+            return _sparse_array_op(self, other, op, op_name)
+        else:
+            with np.errstate(all="ignore"):
+                fill_value = op(self.fill_value, other)
+                result = op(self.sp_values, other)
+
+            return type(self)(
+                result,
+                sparse_index=self.sp_index,
+                fill_value=fill_value,
+                dtype=np.bool_,
+            )
 
-            if isinstance(other, SparseArray):
-                return _sparse_array_op(self, other, op, op_name)
-            else:
-                with np.errstate(all="ignore"):
-                    fill_value = op(self.fill_value, other)
-                    result = op(self.sp_values, other)
+    _logical_method = _cmp_method
 
-                return type(self)(
-                    result,
-                    sparse_index=self.sp_index,
-                    fill_value=fill_value,
-                    dtype=np.bool_,
-                )
+    def _unary_method(self, op) -> "SparseArray":
+        fill_value = op(np.array(self.fill_value)).item()
+        values = op(self.sp_values)
+        dtype = SparseDtype(values.dtype, fill_value)
+        return type(self)._simple_new(values, self.sp_index, dtype)
 
-        name = f"__{op.__name__}__"
-        return compat.set_function_name(cmp_method, name, cls)
+    def __pos__(self) -> "SparseArray":
+        return self._unary_method(operator.pos)
 
-    @classmethod
-    def _add_unary_ops(cls):
-        cls.__pos__ = cls._create_unary_method(operator.pos)
-        cls.__neg__ = cls._create_unary_method(operator.neg)
-        cls.__invert__ = cls._create_unary_method(operator.invert)
+    def __neg__(self) -> "SparseArray":
+        return self._unary_method(operator.neg)
 
-    @classmethod
-    def _add_comparison_ops(cls):
-        cls.__and__ = cls._create_comparison_method(operator.and_)
-        cls.__or__ = cls._create_comparison_method(operator.or_)
-        cls.__xor__ = cls._create_arithmetic_method(operator.xor)
-        super()._add_comparison_ops()
+    def __invert__(self) -> "SparseArray":
+        return self._unary_method(operator.invert)
 
     # ----------
     # Formatting
@@ -1511,8 +1490,6 @@ def _formatter(self, boxed=False):
 
 
 SparseArray._add_arithmetic_ops()
-SparseArray._add_comparison_ops()
-SparseArray._add_unary_ops()
 
 
 def make_sparse(arr: np.ndarray, kind="block", fill_value=None, dtype=None):
diff --git a/setup.cfg b/setup.cfg
index 8d3d79789a252..554c8b30641f0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -142,9 +142,6 @@ check_untyped_defs=False
 [mypy-pandas.core.arrays.datetimelike]
 check_untyped_defs=False
 
-[mypy-pandas.core.arrays.sparse.array]
-check_untyped_defs=False
-
 [mypy-pandas.core.arrays.string_]
 check_untyped_defs=False
 