From 155034ea88ab8a618165b976fd88b8b59347c02f Mon Sep 17 00:00:00 2001
From: "Uwe L. Korn"
Date: Tue, 9 Jun 2020 11:28:02 +0200
Subject: [PATCH 1/3] BUG: Allow plain bools in ExtensionArray.equals

---
 pandas/core/arrays/base.py                |  7 +++++-
 pandas/tests/extension/arrow/arrays.py    | 28 +++++++++++++++++++++--
 pandas/tests/extension/arrow/test_bool.py |  4 ++++
 3 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py
index 79f0039a9df65..315942cd3f249 100644
--- a/pandas/core/arrays/base.py
+++ b/pandas/core/arrays/base.py
@@ -738,7 +738,12 @@ def equals(self, other: "ExtensionArray") -> bool:
                 # boolean array with NA -> fill with False
                 equal_values = equal_values.fillna(False)
             equal_na = self.isna() & other.isna()
-            return (equal_values | equal_na).all().item()
+            result = (equal_values | equal_na).all()
+
+            if isinstance(result, np.bool_):
+                return result.item()
+            else:
+                return result
 
     def _values_for_factorize(self) -> Tuple[np.ndarray, Any]:
         """
diff --git a/pandas/tests/extension/arrow/arrays.py b/pandas/tests/extension/arrow/arrays.py
index ffebc9f8b3359..29cfe1e0fe606 100644
--- a/pandas/tests/extension/arrow/arrays.py
+++ b/pandas/tests/extension/arrow/arrays.py
@@ -8,6 +8,7 @@
 """
 import copy
 import itertools
+import operator
 from typing import Type
 
 import numpy as np
@@ -106,6 +107,27 @@ def astype(self, dtype, copy=True):
     def dtype(self):
         return self._dtype
 
+    def _boolean_op(self, other, op):
+        if not isinstance(other, type(self)):
+            raise NotImplementedError()
+
+        result = op(np.array(self._data), np.array(other._data))
+        return ArrowBoolArray(
+            pa.chunked_array([pa.array(result, mask=pd.isna(self._data.to_pandas()))])
+        )
+
+    def __eq__(self, other):
+        if not isinstance(other, type(self)):
+            return False
+
+        return self._boolean_op(other, operator.eq)
+
+    def __and__(self, other):
+        return self._boolean_op(other, operator.and_)
+
+    def __or__(self, other):
+        return self._boolean_op(other, operator.or_)
+
     @property
     def nbytes(self):
         return sum(
@@ -153,10 +175,12 @@ def _reduce(self, method, skipna=True, **kwargs):
         return op(**kwargs)
 
     def any(self, axis=0, out=None):
-        return self._data.to_pandas().any()
+        # Explicitly return a plain bool to reproduce GH-34660
+        return bool(self._data.to_pandas().any())
 
     def all(self, axis=0, out=None):
-        return self._data.to_pandas().all()
+        # Explicitly return a plain bool to reproduce GH-34660
+        return bool(self._data.to_pandas().all())
 
 
 class ArrowBoolArray(ArrowExtensionArray):
diff --git a/pandas/tests/extension/arrow/test_bool.py b/pandas/tests/extension/arrow/test_bool.py
index 48f1c34764313..ba1502b2b5437 100644
--- a/pandas/tests/extension/arrow/test_bool.py
+++ b/pandas/tests/extension/arrow/test_bool.py
@@ -29,6 +29,10 @@ def data_missing():
     return ArrowBoolArray.from_scalars([None, True])
 
 
+def test_basic_equals(data):
+    assert pd.Series(data).equals(pd.Series(data))
+
+
 class BaseArrowTests:
     pass
 

From ffa60c44039d5075a51cf543a12f9022424dba59 Mon Sep 17 00:00:00 2001
From: "Uwe L. Korn"
Date: Tue, 9 Jun 2020 13:37:41 +0200
Subject: [PATCH 2/3] Update pandas/tests/extension/arrow/test_bool.py

Co-authored-by: Joris Van den Bossche
---
 pandas/tests/extension/arrow/test_bool.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pandas/tests/extension/arrow/test_bool.py b/pandas/tests/extension/arrow/test_bool.py
index ba1502b2b5437..7841360e568ed 100644
--- a/pandas/tests/extension/arrow/test_bool.py
+++ b/pandas/tests/extension/arrow/test_bool.py
@@ -30,6 +30,7 @@ def data_missing():
 
 
 def test_basic_equals(data):
+    # https://github.com/pandas-dev/pandas/issues/34660
     assert pd.Series(data).equals(pd.Series(data))
 
 

From ac8d191205b8ae08739f01f9ea47a01c256066eb Mon Sep 17 00:00:00 2001
From: "Uwe L. Korn"
Date: Tue, 9 Jun 2020 16:35:29 +0200
Subject: [PATCH 3/3] Review comment

---
 pandas/core/arrays/base.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py
index 315942cd3f249..7f2c61ff7d955 100644
--- a/pandas/core/arrays/base.py
+++ b/pandas/core/arrays/base.py
@@ -738,12 +738,7 @@ def equals(self, other: "ExtensionArray") -> bool:
                 # boolean array with NA -> fill with False
                 equal_values = equal_values.fillna(False)
             equal_na = self.isna() & other.isna()
-            result = (equal_values | equal_na).all()
-
-            if isinstance(result, np.bool_):
-                return result.item()
-            else:
-                return result
+            return bool((equal_values | equal_na).all())
 
     def _values_for_factorize(self) -> Tuple[np.ndarray, Any]:
         """