diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 0a6bc97237ddd..53fc38a973110 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -367,7 +367,7 @@ def _cmp_method(self, other, op): pc_func = ARROW_CMP_FUNCS[op.__name__] if isinstance(other, ArrowStringArray): result = pc_func(self._data, other._data) - elif isinstance(other, np.ndarray): + elif isinstance(other, (np.ndarray, list)): result = pc_func(self._data, other) elif is_scalar(other): try: diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index 722222aab6d27..7c3a8c691b786 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -217,15 +217,18 @@ def test_comparison_methods_scalar_pd_na(comparison_op, dtype): tm.assert_extension_array_equal(result, expected) -def test_comparison_methods_scalar_not_string(comparison_op, dtype, request): +def test_comparison_methods_scalar_not_string(comparison_op, dtype): op_name = f"__{comparison_op.__name__}__" - if op_name not in ["__eq__", "__ne__"]: - reason = "comparison op not supported between instances of 'str' and 'int'" - mark = pytest.mark.xfail(raises=TypeError, reason=reason) - request.node.add_marker(mark) a = pd.array(["a", None, "c"], dtype=dtype) other = 42 + + if op_name not in ["__eq__", "__ne__"]: + with pytest.raises(TypeError, match="not supported between"): + getattr(a, op_name)(other) + + return + result = getattr(a, op_name)(other) expected_data = {"__eq__": [False, None, False], "__ne__": [True, None, True]}[ op_name @@ -234,12 +237,7 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype, request): tm.assert_extension_array_equal(result, expected) -def test_comparison_methods_array(comparison_op, dtype, request): - if dtype.storage == "pyarrow": - mark = pytest.mark.xfail( - raises=AssertionError, reason="left is not an ExtensionArray" - ) - request.node.add_marker(mark) +def test_comparison_methods_array(comparison_op, dtype): op_name = f"__{comparison_op.__name__}__" @@ -340,6 +338,17 @@ def test_reduce(skipna, dtype): assert result == "abc" +@pytest.mark.parametrize("skipna", [True, False]) +@pytest.mark.xfail(reason="Not implemented StringArray.sum") +def test_reduce_missing(skipna, dtype): + arr = pd.Series([None, "a", None, "b", "c", None], dtype=dtype) + result = arr.sum(skipna=skipna) + if skipna: + assert result == "abc" + else: + assert pd.isna(result) + + @pytest.mark.parametrize("method", ["min", "max"]) @pytest.mark.parametrize("skipna", [True, False]) def test_min_max(method, skipna, dtype, request): @@ -374,17 +383,6 @@ def test_min_max_numpy(method, box, dtype, request): assert result == expected -@pytest.mark.parametrize("skipna", [True, False]) -@pytest.mark.xfail(reason="Not implemented StringArray.sum") -def test_reduce_missing(skipna, dtype): - arr = pd.Series([None, "a", None, "b", "c", None], dtype=dtype) - result = arr.sum(skipna=skipna) - if skipna: - assert result == "abc" - else: - assert pd.isna(result) - - def test_fillna_args(dtype, request): # GH 37987 diff --git a/pandas/tests/extension/arrow/arrays.py b/pandas/tests/extension/arrow/arrays.py index 1a330bb584ba5..fad28c1896ad0 100644 --- a/pandas/tests/extension/arrow/arrays.py +++ b/pandas/tests/extension/arrow/arrays.py @@ -26,6 +26,7 @@ ) from pandas.api.types import is_scalar from pandas.core.arraylike import OpsMixin +from pandas.core.construction import extract_array @register_extension_dtype @@ -77,6 +78,16 @@ class ArrowExtensionArray(OpsMixin, ExtensionArray): @classmethod def from_scalars(cls, values): + if isinstance(values, cls): + # in particular for empty cases the pa.array(np.asarray(...)) + # does not round-trip + return cls(values._data) + + elif not len(values): + if isinstance(values, list): + dtype = bool if cls is ArrowBoolArray else str + values = np.array([], dtype=dtype) + arr = pa.chunked_array([pa.array(np.asarray(values))]) return cls(arr) @@ -92,6 +103,14 @@ def _from_sequence(cls, scalars, dtype=None, copy=False): def __repr__(self): return f"{type(self).__name__}({repr(self._data)})" + def __contains__(self, obj) -> bool: + if obj is None or obj is self.dtype.na_value: + # None -> EA.__contains__ only checks for self._dtype.na_value, not + # any compatible NA value. + # self.dtype.na_value -> isn't recognized by pd.isna + return bool(self.isna().any()) + return bool(super().__contains__(obj)) + def __getitem__(self, item): if is_scalar(item): return self._data.to_pandas()[item] @@ -125,7 +144,8 @@ def _logical_method(self, other, op): def __eq__(self, other): if not isinstance(other, type(self)): - return False + # TODO: use some pyarrow function here? + return np.asarray(self).__eq__(other) return self._logical_method(other, operator.eq) @@ -144,6 +164,7 @@ def isna(self): def take(self, indices, allow_fill=False, fill_value=None): data = self._data.to_pandas() + data = extract_array(data, extract_numpy=True) if allow_fill and fill_value is None: fill_value = self.dtype.na_value diff --git a/pandas/tests/extension/arrow/test_bool.py b/pandas/tests/extension/arrow/test_bool.py index 9564239f119f3..a73684868e3ae 100644 --- a/pandas/tests/extension/arrow/test_bool.py +++ b/pandas/tests/extension/arrow/test_bool.py @@ -54,8 +54,8 @@ def test_view(self, data): data.view() @pytest.mark.xfail( - raises=AttributeError, - reason="__eq__ incorrectly returns bool instead of ndarray[bool]", + raises=AssertionError, + reason="Doesn't recognize data._na_value as NA", ) def test_contains(self, data, data_missing): super().test_contains(data, data_missing) @@ -77,7 +77,7 @@ def test_series_constructor_scalar_na_with_index(self, dtype, na_value): # pyarrow.lib.ArrowInvalid: only handle 1-dimensional arrays super().test_series_constructor_scalar_na_with_index(dtype, na_value) - @pytest.mark.xfail(reason="raises AssertionError") + @pytest.mark.xfail(reason="ufunc 'invert' not supported for the input types") def test_construct_empty_dataframe(self, dtype): super().test_construct_empty_dataframe(dtype) diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 5049116a9320e..d9351add0fe6d 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -146,9 +146,9 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna): if op_name in ["min", "max"]: return None - s = pd.Series(data) + ser = pd.Series(data) with pytest.raises(TypeError): - getattr(s, op_name)(skipna=skipna) + getattr(ser, op_name)(skipna=skipna) class TestMethods(base.BaseMethodsTests): @@ -166,15 +166,15 @@ class TestCasting(base.BaseCastingTests): class TestComparisonOps(base.BaseComparisonOpsTests): - def _compare_other(self, s, data, op, other): + def _compare_other(self, ser, data, op, other): op_name = f"__{op.__name__}__" - result = getattr(s, op_name)(other) - expected = getattr(s.astype(object), op_name)(other).astype("boolean") + result = getattr(ser, op_name)(other) + expected = getattr(ser.astype(object), op_name)(other).astype("boolean") self.assert_series_equal(result, expected) def test_compare_scalar(self, data, comparison_op): - s = pd.Series(data) - self._compare_other(s, data, comparison_op, "abc") + ser = pd.Series(data) + self._compare_other(ser, data, comparison_op, "abc") class TestParsing(base.BaseParsingTests): diff --git a/pandas/tests/strings/test_string_array.py b/pandas/tests/strings/test_string_array.py index 0de93b479e43e..90c26a747abdd 100644 --- a/pandas/tests/strings/test_string_array.py +++ b/pandas/tests/strings/test_string_array.py @@ -12,13 +12,16 @@ def test_string_array(nullable_string_dtype, any_string_method): method_name, args, kwargs = any_string_method - if method_name == "decode": - pytest.skip("decode requires bytes.") data = ["a", "bb", np.nan, "ccc"] a = Series(data, dtype=object) b = Series(data, dtype=nullable_string_dtype) + if method_name == "decode": + with pytest.raises(TypeError, match="a bytes-like object is required"): + getattr(b.str, method_name)(*args, **kwargs) + return + expected = getattr(a.str, method_name)(*args, **kwargs) result = getattr(b.str, method_name)(*args, **kwargs)