From 379feea3d95fc9259e29948da644008949a6336a Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Fri, 8 Nov 2024 15:24:27 -0500 Subject: [PATCH 1/2] TST (string dtype): fix sql xfails with using_infer_string --- pandas/core/dtypes/cast.py | 2 ++ pandas/core/internals/construction.py | 5 +++-- pandas/io/sql.py | 21 +++++++++++++++++++-- pandas/tests/io/test_sql.py | 22 ++++++++++++---------- 4 files changed, 36 insertions(+), 14 deletions(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 830b84852c704..137a49c4487f6 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -1162,6 +1162,7 @@ def convert_dtypes( def maybe_infer_to_datetimelike( value: npt.NDArray[np.object_], + convert_to_nullable_dtype: bool = False, ) -> np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray | IntervalArray: """ we might have a array (or single object) that is datetime like, @@ -1199,6 +1200,7 @@ def maybe_infer_to_datetimelike( # numpy would have done it for us. convert_numeric=False, convert_non_numeric=True, + convert_to_nullable_dtype=convert_to_nullable_dtype, dtype_if_all_nat=np.dtype("M8[s]"), ) diff --git a/pandas/core/internals/construction.py b/pandas/core/internals/construction.py index 0812ba5e6def4..f357a53a10be8 100644 --- a/pandas/core/internals/construction.py +++ b/pandas/core/internals/construction.py @@ -966,8 +966,9 @@ def convert(arr): if dtype is None: if arr.dtype == np.dtype("O"): # i.e. maybe_convert_objects didn't convert - arr = maybe_infer_to_datetimelike(arr) - if dtype_backend != "numpy" and arr.dtype == np.dtype("O"): + convert_to_nullable_dtype = dtype_backend != "numpy" + arr = maybe_infer_to_datetimelike(arr, convert_to_nullable_dtype) + if convert_to_nullable_dtype and arr.dtype == np.dtype("O"): new_dtype = StringDtype() arr_cls = new_dtype.construct_array_type() arr = arr_cls._from_sequence(arr, dtype=new_dtype) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 9aff5600cf49b..125ca51a456d8 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -45,6 +45,8 @@ from pandas.core.dtypes.common import ( is_dict_like, is_list_like, + is_object_dtype, + is_string_dtype, ) from pandas.core.dtypes.dtypes import ( ArrowDtype, @@ -58,6 +60,7 @@ Series, ) from pandas.core.arrays import ArrowExtensionArray +from pandas.core.arrays.string_ import StringDtype from pandas.core.base import PandasObject import pandas.core.common as com from pandas.core.common import maybe_make_list @@ -1316,7 +1319,12 @@ def _harmonize_columns( elif dtype_backend == "numpy" and col_type is float: # floats support NA, can always convert! self.frame[col_name] = df_col.astype(col_type) - + elif ( + using_string_dtype() + and is_string_dtype(col_type) + and is_object_dtype(self.frame[col_name]) + ): + self.frame[col_name] = df_col.astype(col_type) elif dtype_backend == "numpy" and len(df_col) == df_col.count(): # No NA values, can convert ints and bools if col_type is np.dtype("int64") or col_type is bool: @@ -1403,6 +1411,7 @@ def _get_dtype(self, sqltype): DateTime, Float, Integer, + String, ) if isinstance(sqltype, Float): @@ -1422,6 +1431,10 @@ def _get_dtype(self, sqltype): return date elif isinstance(sqltype, Boolean): return bool + elif isinstance(sqltype, String): + if using_string_dtype(): + return StringDtype(na_value=np.nan) + return object @@ -2205,7 +2218,7 @@ def read_table( elif using_string_dtype(): from pandas.io._util import arrow_string_types_mapper - arrow_string_types_mapper() + mapping = arrow_string_types_mapper() else: mapping = None @@ -2286,6 +2299,10 @@ def read_query( from pandas.io._util import _arrow_dtype_mapping mapping = _arrow_dtype_mapping().get + elif using_string_dtype(): + from pandas.io._util import arrow_string_types_mapper + + mapping = arrow_string_types_mapper() else: mapping = None diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index beca8dea9407d..8be7cbd82b0fd 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -60,7 +60,6 @@ pytest.mark.filterwarnings( "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" ), - pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False), ] @@ -685,6 +684,7 @@ def postgresql_psycopg2_conn(postgresql_psycopg2_engine): @pytest.fixture def postgresql_adbc_conn(): + pytest.importorskip("pyarrow") pytest.importorskip("adbc_driver_postgresql") from adbc_driver_postgresql import dbapi @@ -817,6 +817,7 @@ def sqlite_conn_types(sqlite_engine_types): @pytest.fixture def sqlite_adbc_conn(): + pytest.importorskip("pyarrow") pytest.importorskip("adbc_driver_sqlite") from adbc_driver_sqlite import dbapi @@ -986,13 +987,13 @@ def test_dataframe_to_sql(conn, test_frame1, request): @pytest.mark.parametrize("conn", all_connectable) def test_dataframe_to_sql_empty(conn, test_frame1, request): - if conn == "postgresql_adbc_conn": + if conn == "postgresql_adbc_conn" and not using_string_dtype(): request.node.add_marker( pytest.mark.xfail( - reason="postgres ADBC driver cannot insert index with null type", - strict=True, + reason="postgres ADBC driver < 1.2 cannot insert index with null type", ) ) + # GH 51086 if conn is sqlite_engine conn = request.getfixturevalue(conn) empty_df = test_frame1.iloc[:0] @@ -3557,7 +3558,8 @@ def test_read_sql_dtype_backend( result = getattr(pd, func)( f"Select * from {table}", conn, dtype_backend=dtype_backend ) - expected = dtype_backend_expected(string_storage, dtype_backend, conn_name) + expected = dtype_backend_expected(string_storage, dtype_backend, conn_name) + tm.assert_frame_equal(result, expected) if "adbc" in conn_name: @@ -3607,7 +3609,7 @@ def test_read_sql_dtype_backend_table( with pd.option_context("mode.string_storage", string_storage): result = getattr(pd, func)(table, conn, dtype_backend=dtype_backend) - expected = dtype_backend_expected(string_storage, dtype_backend, conn_name) + expected = dtype_backend_expected(string_storage, dtype_backend, conn_name) tm.assert_frame_equal(result, expected) if "adbc" in conn_name: @@ -4123,7 +4125,7 @@ def tquery(query, con=None): def test_xsqlite_basic(sqlite_buildin): frame = DataFrame( np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=10, freq="B"), ) assert sql.to_sql(frame, name="test_table", con=sqlite_buildin, index=False) == 10 @@ -4150,7 +4152,7 @@ def test_xsqlite_basic(sqlite_buildin): def test_xsqlite_write_row_by_row(sqlite_buildin): frame = DataFrame( np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=10, freq="B"), ) frame.iloc[0, 0] = np.nan @@ -4173,7 +4175,7 @@ def test_xsqlite_write_row_by_row(sqlite_buildin): def test_xsqlite_execute(sqlite_buildin): frame = DataFrame( np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=10, freq="B"), ) create_sql = sql.get_schema(frame, "test") @@ -4194,7 +4196,7 @@ def test_xsqlite_execute(sqlite_buildin): def test_xsqlite_schema(sqlite_buildin): frame = DataFrame( np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=10, freq="B"), ) create_sql = sql.get_schema(frame, "test") From 9eb9232320daf88d0b1286da2d38c6bfc7a59559 Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Sat, 9 Nov 2024 14:04:57 -0500 Subject: [PATCH 2/2] Add single_cpu mark --- pandas/tests/io/test_sql.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 8be7cbd82b0fd..96d63d3fe25e5 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -60,6 +60,7 @@ pytest.mark.filterwarnings( "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" ), + pytest.mark.single_cpu, ]