Skip to content

[backport 2.3.x] String dtype: enable in SQL IO + resolve all xfails (#60255) #60315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2741,7 +2741,13 @@ def maybe_convert_objects(ndarray[object] objects,
seen.object_ = True

elif seen.str_:
if using_string_dtype() and is_string_array(objects, skipna=True):
if convert_to_nullable_dtype and is_string_array(objects, skipna=True):
from pandas.core.arrays.string_ import StringDtype

dtype = StringDtype()
return dtype.construct_array_type()._from_sequence(objects, dtype=dtype)

elif using_string_dtype() and is_string_array(objects, skipna=True):
from pandas.core.arrays.string_ import StringDtype

dtype = StringDtype(na_value=np.nan)
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,7 @@ def convert_dtypes(

def maybe_infer_to_datetimelike(
value: npt.NDArray[np.object_],
convert_to_nullable_dtype: bool = False,
) -> np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray | IntervalArray:
"""
we might have a array (or single object) that is datetime like,
Expand Down Expand Up @@ -1200,6 +1201,7 @@ def maybe_infer_to_datetimelike(
# numpy would have done it for us.
convert_numeric=False,
convert_non_numeric=True,
convert_to_nullable_dtype=convert_to_nullable_dtype,
dtype_if_all_nat=np.dtype("M8[ns]"),
)

Expand Down
5 changes: 3 additions & 2 deletions pandas/core/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,8 +1042,9 @@ def convert(arr):
if dtype is None:
if arr.dtype == np.dtype("O"):
# i.e. maybe_convert_objects didn't convert
arr = maybe_infer_to_datetimelike(arr)
if dtype_backend != "numpy" and arr.dtype == np.dtype("O"):
convert_to_nullable_dtype = dtype_backend != "numpy"
arr = maybe_infer_to_datetimelike(arr, convert_to_nullable_dtype)
if convert_to_nullable_dtype and arr.dtype == np.dtype("O"):
new_dtype = StringDtype()
arr_cls = new_dtype.construct_array_type()
arr = arr_cls._from_sequence(arr, dtype=new_dtype)
Expand Down
21 changes: 19 additions & 2 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
from pandas.core.dtypes.common import (
is_dict_like,
is_list_like,
is_object_dtype,
is_string_dtype,
)
from pandas.core.dtypes.dtypes import (
ArrowDtype,
Expand All @@ -59,6 +61,7 @@
Series,
)
from pandas.core.arrays import ArrowExtensionArray
from pandas.core.arrays.string_ import StringDtype
from pandas.core.base import PandasObject
import pandas.core.common as com
from pandas.core.common import maybe_make_list
Expand Down Expand Up @@ -1331,7 +1334,12 @@ def _harmonize_columns(
elif dtype_backend == "numpy" and col_type is float:
# floats support NA, can always convert!
self.frame[col_name] = df_col.astype(col_type, copy=False)

elif (
using_string_dtype()
and is_string_dtype(col_type)
and is_object_dtype(self.frame[col_name])
):
self.frame[col_name] = df_col.astype(col_type, copy=False)
elif dtype_backend == "numpy" and len(df_col) == df_col.count():
# No NA values, can convert ints and bools
if col_type is np.dtype("int64") or col_type is bool:
Expand Down Expand Up @@ -1418,6 +1426,7 @@ def _get_dtype(self, sqltype):
DateTime,
Float,
Integer,
String,
)

if isinstance(sqltype, Float):
Expand All @@ -1437,6 +1446,10 @@ def _get_dtype(self, sqltype):
return date
elif isinstance(sqltype, Boolean):
return bool
elif isinstance(sqltype, String):
if using_string_dtype():
return StringDtype(na_value=np.nan)

return object


Expand Down Expand Up @@ -2218,7 +2231,7 @@ def read_table(
elif using_string_dtype():
from pandas.io._util import arrow_string_types_mapper

arrow_string_types_mapper()
mapping = arrow_string_types_mapper()
else:
mapping = None

Expand Down Expand Up @@ -2299,6 +2312,10 @@ def read_query(
from pandas.io._util import _arrow_dtype_mapping

mapping = _arrow_dtype_mapping().get
elif using_string_dtype():
from pandas.io._util import arrow_string_types_mapper

mapping = arrow_string_types_mapper()
else:
mapping = None

Expand Down
23 changes: 13 additions & 10 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
pytest.mark.filterwarnings(
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
),
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
pytest.mark.single_cpu,
]


Expand Down Expand Up @@ -685,6 +685,7 @@ def postgresql_psycopg2_conn(postgresql_psycopg2_engine):

@pytest.fixture
def postgresql_adbc_conn():
pytest.importorskip("pyarrow")
pytest.importorskip("adbc_driver_postgresql")
from adbc_driver_postgresql import dbapi

Expand Down Expand Up @@ -817,6 +818,7 @@ def sqlite_conn_types(sqlite_engine_types):

@pytest.fixture
def sqlite_adbc_conn():
pytest.importorskip("pyarrow")
pytest.importorskip("adbc_driver_sqlite")
from adbc_driver_sqlite import dbapi

Expand Down Expand Up @@ -986,13 +988,13 @@ def test_dataframe_to_sql(conn, test_frame1, request):

@pytest.mark.parametrize("conn", all_connectable)
def test_dataframe_to_sql_empty(conn, test_frame1, request):
if conn == "postgresql_adbc_conn":
if conn == "postgresql_adbc_conn" and not using_string_dtype():
request.node.add_marker(
pytest.mark.xfail(
reason="postgres ADBC driver cannot insert index with null type",
strict=True,
reason="postgres ADBC driver < 1.2 cannot insert index with null type",
)
)

# GH 51086 if conn is sqlite_engine
conn = request.getfixturevalue(conn)
empty_df = test_frame1.iloc[:0]
Expand Down Expand Up @@ -3571,7 +3573,8 @@ def test_read_sql_dtype_backend(
result = getattr(pd, func)(
f"Select * from {table}", conn, dtype_backend=dtype_backend
)
expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)

tm.assert_frame_equal(result, expected)

if "adbc" in conn_name:
Expand Down Expand Up @@ -3621,7 +3624,7 @@ def test_read_sql_dtype_backend_table(

with pd.option_context("mode.string_storage", string_storage):
result = getattr(pd, func)(table, conn, dtype_backend=dtype_backend)
expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
expected = dtype_backend_expected(string_storage, dtype_backend, conn_name)
tm.assert_frame_equal(result, expected)

if "adbc" in conn_name:
Expand Down Expand Up @@ -4150,7 +4153,7 @@ def tquery(query, con=None):
def test_xsqlite_basic(sqlite_buildin):
frame = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=10, freq="B"),
)
assert sql.to_sql(frame, name="test_table", con=sqlite_buildin, index=False) == 10
Expand All @@ -4177,7 +4180,7 @@ def test_xsqlite_basic(sqlite_buildin):
def test_xsqlite_write_row_by_row(sqlite_buildin):
frame = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=10, freq="B"),
)
frame.iloc[0, 0] = np.nan
Expand All @@ -4200,7 +4203,7 @@ def test_xsqlite_write_row_by_row(sqlite_buildin):
def test_xsqlite_execute(sqlite_buildin):
frame = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=10, freq="B"),
)
create_sql = sql.get_schema(frame, "test")
Expand All @@ -4221,7 +4224,7 @@ def test_xsqlite_execute(sqlite_buildin):
def test_xsqlite_schema(sqlite_buildin):
frame = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=10, freq="B"),
)
create_sql = sql.get_schema(frame, "test")
Expand Down
Loading