diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 79f0039a9df65..7f2c61ff7d955 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -738,7 +738,7 @@ def equals(self, other: "ExtensionArray") -> bool: # boolean array with NA -> fill with False equal_values = equal_values.fillna(False) equal_na = self.isna() & other.isna() - return (equal_values | equal_na).all().item() + return bool((equal_values | equal_na).all()) def _values_for_factorize(self) -> Tuple[np.ndarray, Any]: """ diff --git a/pandas/tests/extension/arrow/arrays.py b/pandas/tests/extension/arrow/arrays.py index ffebc9f8b3359..29cfe1e0fe606 100644 --- a/pandas/tests/extension/arrow/arrays.py +++ b/pandas/tests/extension/arrow/arrays.py @@ -8,6 +8,7 @@ """ import copy import itertools +import operator from typing import Type import numpy as np @@ -106,6 +107,27 @@ def astype(self, dtype, copy=True): def dtype(self): return self._dtype + def _boolean_op(self, other, op): + if not isinstance(other, type(self)): + raise NotImplementedError() + + result = op(np.array(self._data), np.array(other._data)) + return ArrowBoolArray( + pa.chunked_array([pa.array(result, mask=pd.isna(self._data.to_pandas()))]) + ) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + + return self._boolean_op(other, operator.eq) + + def __and__(self, other): + return self._boolean_op(other, operator.and_) + + def __or__(self, other): + return self._boolean_op(other, operator.or_) + @property def nbytes(self): return sum( @@ -153,10 +175,12 @@ def _reduce(self, method, skipna=True, **kwargs): return op(**kwargs) def any(self, axis=0, out=None): - return self._data.to_pandas().any() + # Explicitly return a plain bool to reproduce GH-34660 + return bool(self._data.to_pandas().any()) def all(self, axis=0, out=None): - return self._data.to_pandas().all() + # Explicitly return a plain bool to reproduce GH-34660 + return bool(self._data.to_pandas().all()) class ArrowBoolArray(ArrowExtensionArray): diff --git a/pandas/tests/extension/arrow/test_bool.py b/pandas/tests/extension/arrow/test_bool.py index 48f1c34764313..7841360e568ed 100644 --- a/pandas/tests/extension/arrow/test_bool.py +++ b/pandas/tests/extension/arrow/test_bool.py @@ -29,6 +29,11 @@ def data_missing(): return ArrowBoolArray.from_scalars([None, True]) +def test_basic_equals(data): + # https://github.com/pandas-dev/pandas/issues/34660 + assert pd.Series(data).equals(pd.Series(data)) + + class BaseArrowTests: pass