From 5035b116d8afa4e5436b3f650d6f96b558ad1cc1 Mon Sep 17 00:00:00 2001
From: Samuel Chai
Date: Thu, 8 Feb 2024 21:28:29 -0500
Subject: [PATCH 01/11] Fixing multi method for to_sql for non-Oracle databases

---
 pandas/io/sql.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/pandas/io/sql.py b/pandas/io/sql.py
index 08f99a4d3093a..a484da77671f5 100644
--- a/pandas/io/sql.py
+++ b/pandas/io/sql.py
@@ -1013,12 +1013,19 @@ def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int:
 
         from sqlalchemy import insert
 
+        # For Oracle compliance we do not allow multi statements
+        dialects_not_supporting_multi = ["oracle"]
+
+        if conn.dialect is not None and conn.dialect.name not in dialects_not_supporting_multi:
+            data = [dict(zip(keys, row)) for row in data_iter]
+            stmt = insert(self.table).values(data)
+            result = conn.execute(stmt)
+            return result.rowcount
+
+        # For compliance with Oracle, fall back to an executemany-style insert.
+        # see: https://docs.sqlalchemy.org/en/20/core/dml.html#sqlalchemy.sql.expression.Insert.values
         data = [dict(zip(keys, row)) for row in data_iter]
         stmt = insert(self.table)
-        # conn.execute is used here to ensure compatibility with Oracle.
-        # Using stmt.values(data) would produce a multi row insert that
-        # isn't supported by Oracle.
-        # see: https://docs.sqlalchemy.org/en/20/core/dml.html#sqlalchemy.sql.expression.Insert.values
         result = conn.execute(stmt, data)
         return result.rowcount

From 7f4b5f8d9fbdf2a990deefbf3a772b3c8fc45dd1 Mon Sep 17 00:00:00 2001
From: Samuel Chai
Date: Thu, 8 Feb 2024 21:32:38 -0500
Subject: [PATCH 02/11] Simplifying the if statement

---
 pandas/io/sql.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/pandas/io/sql.py b/pandas/io/sql.py
index a484da77671f5..ed64712e5f6c4 100644
--- a/pandas/io/sql.py
+++ b/pandas/io/sql.py
@@ -1016,17 +1016,15 @@ def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int:
         # For Oracle compliance we do not allow multi statements
         dialects_not_supporting_multi = ["oracle"]
 
+        data = [dict(zip(keys, row)) for row in data_iter]
+
         if conn.dialect is not None and conn.dialect.name not in dialects_not_supporting_multi:
-            data = [dict(zip(keys, row)) for row in data_iter]
+            # Dialects other than Oracle accept a single multi-value INSERT
             stmt = insert(self.table).values(data)
             result = conn.execute(stmt)
-            return result.rowcount
-
-        # For compliance with Oracle, fall back to an executemany-style insert.
-        # see: https://docs.sqlalchemy.org/en/20/core/dml.html#sqlalchemy.sql.expression.Insert.values
-        data = [dict(zip(keys, row)) for row in data_iter]
-        stmt = insert(self.table)
-        result = conn.execute(stmt, data)
+        else:
+            stmt = insert(self.table)
+            result = conn.execute(stmt, data)
         return result.rowcount
 
     def insert_data(self) -> tuple[list[str], list[np.ndarray]]:

From 641b144227593e2c43f4d2cad7270e45c039fead Mon Sep 17 00:00:00 2001
From: Samuel Chai
Date: Thu, 8 Feb 2024 21:43:34 -0500
Subject: [PATCH 03/11] adding a doc

---
 doc/source/whatsnew/v2.2.1.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/doc/source/whatsnew/v2.2.1.rst b/doc/source/whatsnew/v2.2.1.rst
index 883627bd4b19b..b360e1233bdb4 100644
--- a/doc/source/whatsnew/v2.2.1.rst
+++ b/doc/source/whatsnew/v2.2.1.rst
@@ -26,6 +26,7 @@ Fixed regressions
 - Fixed regression in :meth:`Index.join` raising ``TypeError`` when joining an empty index to a non-empty index containing mixed dtype values (:issue:`57048`)
 - Fixed regression in :meth:`Series.pct_change` raising a ``ValueError`` for an empty :class:`Series` (:issue:`57056`)
 - Fixed regression
in :meth:`Series.to_numpy` when dtype is given as float and the data contains NaNs (:issue:`57121`) +- Fixed regression in :meth:`DataFrame.to_sql` when method="multi" is passed and the dialect type is not Oracle (:issue:`57310`) .. --------------------------------------------------------------------------- .. _whatsnew_221.bug_fixes: From 0f846d4595484bc466ccffc6badbc559e00f05d4 Mon Sep 17 00:00:00 2001 From: Samuel Chai Date: Mon, 12 Feb 2024 09:56:38 -0500 Subject: [PATCH 04/11] Adding unit test --- pandas/io/sql.py | 15 +-- pandas/tests/io/test_sql.py | 247 ++++++++++++++++++++---------------- 2 files changed, 143 insertions(+), 119 deletions(-) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index ed64712e5f6c4..55414f1cde6a4 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -1009,22 +1009,15 @@ def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int: Note: multi-value insert is usually faster for analytics DBs and tables containing a few columns but performance degrades quickly with increase of columns. + + Note: Oracle does not support multi-value insert """ from sqlalchemy import insert - # For Oracle compliance we do not allow multi statements - dialects_not_supporting_multi = ["oracle"] - data = [dict(zip(keys, row)) for row in data_iter] - - if conn.dialect is not None and conn.dialect.name not in dialects_not_supporting_multi: - # For Oracle compliance we do not allow multi statements - stmt = insert(self.table).values(data) - result = conn.execute(stmt) - else: - stmt = insert(self.table) - result = conn.execute(stmt, data) + stmt = insert(self.table).values(data) + result = conn.execute(stmt) return result.rowcount def insert_data(self) -> tuple[list[str], list[np.ndarray]]: diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 8bb67fac19c65..43f369a461fb6 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re import contextlib from contextlib import closing import csv @@ -60,7 +61,6 @@ if TYPE_CHECKING: import sqlalchemy - pytestmark = pytest.mark.filterwarnings( "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" ) @@ -500,7 +500,7 @@ def test_frame1(): def test_frame3(): columns = ["index", "A", "B"] data = [ - ("2000-01-03 00:00:00", 2**31 - 1, -1.987670), + ("2000-01-03 00:00:00", 2 ** 31 - 1, -1.987670), ("2000-01-04 00:00:00", -29, -0.0412318367011), ("2000-01-05 00:00:00", 20000, 0.731167677815), ("2000-01-06 00:00:00", -290867, 1.56762092543), @@ -558,8 +558,8 @@ def get_all_tables(conn): def drop_table( - table_name: str, - conn: sqlite3.Connection | sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, + table_name: str, + conn: sqlite3.Connection | sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, ): if isinstance(conn, sqlite3.Connection): conn.execute(f"DROP TABLE IF EXISTS {sql._get_valid_sqlite_name(table_name)}") @@ -577,8 +577,8 @@ def drop_table( def drop_view( - view_name: str, - conn: sqlite3.Connection | sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, + view_name: str, + conn: sqlite3.Connection | sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, ): import sqlalchemy @@ -942,11 +942,11 @@ def sqlite_buildin_types(sqlite_buildin, types_data): sqlalchemy_connectable = mysql_connectable + postgresql_connectable + sqlite_connectable sqlalchemy_connectable_iris = ( - mysql_connectable_iris + postgresql_connectable_iris + sqlite_connectable_iris + mysql_connectable_iris + 
postgresql_connectable_iris + sqlite_connectable_iris ) sqlalchemy_connectable_types = ( - mysql_connectable_types + postgresql_connectable_types + sqlite_connectable_types + mysql_connectable_types + postgresql_connectable_types + sqlite_connectable_types ) adbc_connectable = [ @@ -964,15 +964,14 @@ def sqlite_buildin_types(sqlite_buildin, types_data): pytest.param("sqlite_adbc_types", marks=pytest.mark.db), ] - all_connectable = sqlalchemy_connectable + ["sqlite_buildin"] + adbc_connectable all_connectable_iris = ( - sqlalchemy_connectable_iris + ["sqlite_buildin_iris"] + adbc_connectable_iris + sqlalchemy_connectable_iris + ["sqlite_buildin_iris"] + adbc_connectable_iris ) all_connectable_types = ( - sqlalchemy_connectable_types + ["sqlite_buildin_types"] + adbc_connectable_types + sqlalchemy_connectable_types + ["sqlite_buildin_types"] + adbc_connectable_types ) @@ -1783,15 +1782,15 @@ def test_api_date_parsing(conn, request): (sql.read_sql, "SELECT * FROM types", ("sqlalchemy", "fallback")), (sql.read_sql, "types", ("sqlalchemy")), ( - sql.read_sql_query, - "SELECT * FROM types", - ("sqlalchemy", "fallback"), + sql.read_sql_query, + "SELECT * FROM types", + ("sqlalchemy", "fallback"), ), (sql.read_sql_table, "types", ("sqlalchemy")), ], ) def test_api_custom_dateparsing_error( - conn, request, read_sql, text, mode, error, types_data_frame + conn, request, read_sql, text, mode, error, types_data_frame ): conn_name = conn conn = request.getfixturevalue(conn) @@ -2390,12 +2389,12 @@ def test_warning_case_insensitive_table_name(conn, request, test_frame1): conn = request.getfixturevalue(conn) # see gh-7815 with tm.assert_produces_warning( - UserWarning, - match=( - r"The provided table name 'TABLE1' is not found exactly as such in " - r"the database after writing the table, possibly due to case " - r"sensitivity issues. Consider using lower case table names." - ), + UserWarning, + match=( + r"The provided table name 'TABLE1' is not found exactly as such in " + r"the database after writing the table, possibly due to case " + r"sensitivity issues. Consider using lower case table names." 
+ ), ): with sql.SQLDatabase(conn) as db: db.check_case_sensitive("TABLE1", "") @@ -2460,7 +2459,7 @@ def test_sqlalchemy_integer_overload_mapping(conn, request, integer): df = DataFrame([0, 1], columns=["a"], dtype=integer) with sql.SQLDatabase(conn) as db: with pytest.raises( - ValueError, match="Unsigned 64 bit integer datatype is not supported" + ValueError, match="Unsigned 64 bit integer datatype is not supported" ): sql.SQLTable("test_type", db, frame=df) @@ -2751,7 +2750,7 @@ def test_sqlalchemy_default_type_conversion(conn, request): def test_bigint(conn, request): # int64 should be converted to BigInteger, GH7433 conn = request.getfixturevalue(conn) - df = DataFrame(data={"i64": [2**62]}) + df = DataFrame(data={"i64": [2 ** 62]}) assert df.to_sql(name="test_bigint", con=conn, index=False) == 1 result = sql.read_sql_table("test_bigint", conn) @@ -2990,7 +2989,7 @@ def test_datetime_time(conn, request, sqlite_buildin): def test_mixed_dtype_insert(conn, request): # see GH6509 conn = request.getfixturevalue(conn) - s1 = Series(2**25 + 1, dtype=np.int32) + s1 = Series(2 ** 25 + 1, dtype=np.int32) s2 = Series(0.0, dtype=np.float32) df = DataFrame({"s1": s1, "s2": s2}) @@ -3288,14 +3287,14 @@ def test_double_precision(conn, request): ) assert ( - df.to_sql( - name="test_dtypes", - con=conn, - index=False, - if_exists="replace", - dtype={"f64_as_f32": Float(precision=23)}, - ) - == 1 + df.to_sql( + name="test_dtypes", + con=conn, + index=False, + if_exists="replace", + dtype={"f64_as_f32": Float(precision=23)}, + ) + == 1 ) res = sql.read_sql_table("test_dtypes", conn) @@ -3343,8 +3342,8 @@ def main(connectable): test_connectable(connectable) assert ( - DataFrame({"test_foo_data": [0, 1, 2]}).to_sql(name="test_foo_data", con=conn) - == 3 + DataFrame({"test_foo_data": [0, 1, 2]}).to_sql(name="test_foo_data", con=conn) + == 3 ) main(conn) @@ -3503,13 +3502,13 @@ def test_get_engine_auto_error_message(): @pytest.mark.parametrize("conn", all_connectable) @pytest.mark.parametrize("func", ["read_sql", "read_sql_query"]) def test_read_sql_dtype_backend( - conn, - request, - string_storage, - func, - dtype_backend, - dtype_backend_data, - dtype_backend_expected, + conn, + request, + string_storage, + func, + dtype_backend, + dtype_backend_data, + dtype_backend_expected, ): # GH#50048 conn_name = conn @@ -3546,13 +3545,13 @@ def test_read_sql_dtype_backend( @pytest.mark.parametrize("conn", all_connectable) @pytest.mark.parametrize("func", ["read_sql", "read_sql_table"]) def test_read_sql_dtype_backend_table( - conn, - request, - string_storage, - func, - dtype_backend, - dtype_backend_data, - dtype_backend_expected, + conn, + request, + string_storage, + func, + dtype_backend, + dtype_backend_data, + dtype_backend_expected, ): if "sqlite" in conn and "adbc" not in conn: request.applymarker( @@ -3693,10 +3692,10 @@ def test_chunksize_empty_dtypes(conn, request): df.to_sql(name="test", con=conn, index=False, if_exists="replace") for result in read_sql_query( - "SELECT * FROM test", - conn, - dtype=dtypes, - chunksize=1, + "SELECT * FROM test", + conn, + dtype=dtypes, + chunksize=1, ): tm.assert_frame_equal(result, expected) @@ -3775,8 +3774,8 @@ class Test(BaseModel): with Session() as session: df = DataFrame({"id": [0, 1], "string_column": ["hello", "world"]}) assert ( - df.to_sql(name="test_frame", con=conn, index=False, if_exists="replace") - == 2 + df.to_sql(name="test_frame", con=conn, index=False, if_exists="replace") + == 2 ) session.commit() test_query = session.query(Test.id, 
Test.string_column) @@ -3816,7 +3815,7 @@ def test_roundtripping_datetimes(sqlite_engine): @pytest.fixture def sqlite_builtin_detect_types(): with contextlib.closing( - sqlite3.connect(":memory:", detect_types=sqlite3.PARSE_DECLTYPES) + sqlite3.connect(":memory:", detect_types=sqlite3.PARSE_DECLTYPES) ) as closing_conn: with closing_conn as conn: yield conn @@ -3848,16 +3847,16 @@ def test_psycopg2_schema_support(postgresql_psycopg2_engine): # write dataframe to different schema's assert df.to_sql(name="test_schema_public", con=conn, index=False) == 2 assert ( - df.to_sql( - name="test_schema_public_explicit", - con=conn, - index=False, - schema="public", - ) - == 2 + df.to_sql( + name="test_schema_public_explicit", + con=conn, + index=False, + schema="public", + ) + == 2 ) assert ( - df.to_sql(name="test_schema_other", con=conn, index=False, schema="other") == 2 + df.to_sql(name="test_schema_other", con=conn, index=False, schema="other") == 2 ) # read dataframes back in @@ -3883,7 +3882,7 @@ def test_psycopg2_schema_support(postgresql_psycopg2_engine): # write dataframe with different if_exists options assert ( - df.to_sql(name="test_schema_other", con=conn, schema="other", index=False) == 2 + df.to_sql(name="test_schema_other", con=conn, schema="other", index=False) == 2 ) df.to_sql( name="test_schema_other", @@ -3893,14 +3892,14 @@ def test_psycopg2_schema_support(postgresql_psycopg2_engine): if_exists="replace", ) assert ( - df.to_sql( - name="test_schema_other", - con=conn, - schema="other", - index=False, - if_exists="append", - ) - == 2 + df.to_sql( + name="test_schema_other", + con=conn, + schema="other", + index=False, + if_exists="append", + ) + == 2 ) res = sql.read_sql_table("test_schema_other", conn, schema="other") tm.assert_frame_equal(concat([df, df], ignore_index=True), res) @@ -4044,18 +4043,18 @@ def test_sqlite_illegal_names(sqlite_buildin): df.to_sql(name="", con=conn) for ndx, weird_name in enumerate( - [ - "test_weird_name]", - "test_weird_name[", - "test_weird_name`", - 'test_weird_name"', - "test_weird_name'", - "_b.test_weird_name_01-30", - '"_b.test_weird_name_01-30"', - "99beginswithnumber", - "12345", - "\xe9", - ] + [ + "test_weird_name]", + "test_weird_name[", + "test_weird_name`", + 'test_weird_name"', + "test_weird_name'", + "_b.test_weird_name_01-30", + '"_b.test_weird_name_01-30"', + "99beginswithnumber", + "12345", + "\xe9", + ] ): assert df.to_sql(name=weird_name, con=conn) == 2 sql.table_exists(weird_name, conn) @@ -4289,39 +4288,39 @@ def test_xsqlite_if_exists(sqlite_buildin): ) assert tquery(sql_select, con=sqlite_buildin) == [(1, "A"), (2, "B")] assert ( - sql.to_sql( - frame=df_if_exists_2, - con=sqlite_buildin, - name=table_name, - if_exists="replace", - index=False, - ) - == 3 + sql.to_sql( + frame=df_if_exists_2, + con=sqlite_buildin, + name=table_name, + if_exists="replace", + index=False, + ) + == 3 ) assert tquery(sql_select, con=sqlite_buildin) == [(3, "C"), (4, "D"), (5, "E")] drop_table(table_name, sqlite_buildin) # test if_exists='append' assert ( - sql.to_sql( - frame=df_if_exists_1, - con=sqlite_buildin, - name=table_name, - if_exists="fail", - index=False, - ) - == 2 + sql.to_sql( + frame=df_if_exists_1, + con=sqlite_buildin, + name=table_name, + if_exists="fail", + index=False, + ) + == 2 ) assert tquery(sql_select, con=sqlite_buildin) == [(1, "A"), (2, "B")] assert ( - sql.to_sql( - frame=df_if_exists_2, - con=sqlite_buildin, - name=table_name, - if_exists="append", - index=False, - ) - == 3 + sql.to_sql( + frame=df_if_exists_2, + 
con=sqlite_buildin, + name=table_name, + if_exists="append", + index=False, + ) + == 3 ) assert tquery(sql_select, con=sqlite_buildin) == [ (1, "A"), @@ -4331,3 +4330,35 @@ def test_xsqlite_if_exists(sqlite_buildin): (5, "E"), ] drop_table(table_name, sqlite_buildin) + + +def test_execution_of_multi(mysql_pymysql_engine): + + from sqlalchemy import event + from pandas.io.sql import SQLTable + original_function = SQLTable._execute_insert_multi + + frame = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object)) + + statements = [] + + def track_statements(_, __, statement, ___, ____, _____): + nonlocal statements + statements.append(statement) + + def pandas_insert_patched(self, *args, **kwargs): + event.listen(args[0], "before_cursor_execute", track_statements) + return original_function(self, *args, **kwargs) + + pd.io.sql.SQLTable._execute_insert_multi = pandas_insert_patched + + frame.to_sql("test_multi_prepared_statement", mysql_pymysql_engine, method="multi", index=False) + sql_statement = statements[-1] + + pattern = r'\([^()]+\)' + + matches = re.findall(pattern, sql_statement) + + assert len([a for a in matches if a.startswith("(A_")]) > 1 From 51cadc9f475c09aa6f72d8e4de47878b5cf450ee Mon Sep 17 00:00:00 2001 From: Samuel Chai <121340503+kassett@users.noreply.github.com> Date: Mon, 12 Feb 2024 14:58:56 -0500 Subject: [PATCH 05/11] Update doc/source/whatsnew/v2.2.1.rst Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> --- doc/source/whatsnew/v2.2.1.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.2.1.rst b/doc/source/whatsnew/v2.2.1.rst index b360e1233bdb4..f52d4673148f0 100644 --- a/doc/source/whatsnew/v2.2.1.rst +++ b/doc/source/whatsnew/v2.2.1.rst @@ -26,7 +26,7 @@ Fixed regressions - Fixed regression in :meth:`Index.join` raising ``TypeError`` when joining an empty index to a non-empty index containing mixed dtype values (:issue:`57048`) - Fixed regression in :meth:`Series.pct_change` raising a ``ValueError`` for an empty :class:`Series` (:issue:`57056`) - Fixed regression in :meth:`Series.to_numpy` when dtype is given as float and the data contains NaNs (:issue:`57121`) -- Fixed regression in :meth:`DataFrame.to_sql` when method="multi" is passed and the dialect type is not Oracle (:issue:`57310`) +- Fixed regression in :meth:`DataFrame.to_sql` when ``method="multi"`` is passed and the dialect type is not Oracle (:issue:`57310`) .. --------------------------------------------------------------------------- .. _whatsnew_221.bug_fixes: From 22c3b86dc128dd857555448bbc5a11ba8acfffdd Mon Sep 17 00:00:00 2001 From: Samuel Chai Date: Mon, 12 Feb 2024 15:14:03 -0500 Subject: [PATCH 06/11] Reverted formatting in test_sql --- pandas/core/generic.py | 3 + pandas/io/sql.py | 3 +- pandas/tests/io/test_sql.py | 216 ++++++++++++++++++------------------ 3 files changed, 113 insertions(+), 109 deletions(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 3c71784ad81c4..4ec586eb16153 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2912,6 +2912,9 @@ def to_sql( ``Timestamp with timezone`` type with SQLAlchemy if supported by the database. Otherwise, the datetimes will be stored as timezone unaware timestamps local to the original timezone. + + Not all datastores support ``method="multi"``. 
Oracle, for example, + does not support multi-value insert References ---------- diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 55414f1cde6a4..aff88712c9dc7 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -1004,13 +1004,12 @@ def _execute_insert(self, conn, keys: list[str], data_iter) -> int: def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int: """ - Alternative to _execute_insert for DBs support multivalue INSERT. + Alternative to _execute_insert for DBs support multi-value INSERT. Note: multi-value insert is usually faster for analytics DBs and tables containing a few columns but performance degrades quickly with increase of columns. - Note: Oracle does not support multi-value insert """ from sqlalchemy import insert diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 43f369a461fb6..4fe330288f46f 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -61,6 +61,7 @@ if TYPE_CHECKING: import sqlalchemy + pytestmark = pytest.mark.filterwarnings( "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" ) @@ -500,7 +501,7 @@ def test_frame1(): def test_frame3(): columns = ["index", "A", "B"] data = [ - ("2000-01-03 00:00:00", 2 ** 31 - 1, -1.987670), + ("2000-01-03 00:00:00", 2**31 - 1, -1.987670), ("2000-01-04 00:00:00", -29, -0.0412318367011), ("2000-01-05 00:00:00", 20000, 0.731167677815), ("2000-01-06 00:00:00", -290867, 1.56762092543), @@ -558,8 +559,8 @@ def get_all_tables(conn): def drop_table( - table_name: str, - conn: sqlite3.Connection | sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, + table_name: str, + conn: sqlite3.Connection | sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, ): if isinstance(conn, sqlite3.Connection): conn.execute(f"DROP TABLE IF EXISTS {sql._get_valid_sqlite_name(table_name)}") @@ -577,8 +578,8 @@ def drop_table( def drop_view( - view_name: str, - conn: sqlite3.Connection | sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, + view_name: str, + conn: sqlite3.Connection | sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, ): import sqlalchemy @@ -942,11 +943,11 @@ def sqlite_buildin_types(sqlite_buildin, types_data): sqlalchemy_connectable = mysql_connectable + postgresql_connectable + sqlite_connectable sqlalchemy_connectable_iris = ( - mysql_connectable_iris + postgresql_connectable_iris + sqlite_connectable_iris + mysql_connectable_iris + postgresql_connectable_iris + sqlite_connectable_iris ) sqlalchemy_connectable_types = ( - mysql_connectable_types + postgresql_connectable_types + sqlite_connectable_types + mysql_connectable_types + postgresql_connectable_types + sqlite_connectable_types ) adbc_connectable = [ @@ -964,14 +965,15 @@ def sqlite_buildin_types(sqlite_buildin, types_data): pytest.param("sqlite_adbc_types", marks=pytest.mark.db), ] + all_connectable = sqlalchemy_connectable + ["sqlite_buildin"] + adbc_connectable all_connectable_iris = ( - sqlalchemy_connectable_iris + ["sqlite_buildin_iris"] + adbc_connectable_iris + sqlalchemy_connectable_iris + ["sqlite_buildin_iris"] + adbc_connectable_iris ) all_connectable_types = ( - sqlalchemy_connectable_types + ["sqlite_buildin_types"] + adbc_connectable_types + sqlalchemy_connectable_types + ["sqlite_buildin_types"] + adbc_connectable_types ) @@ -1782,15 +1784,15 @@ def test_api_date_parsing(conn, request): (sql.read_sql, "SELECT * FROM types", ("sqlalchemy", "fallback")), (sql.read_sql, "types", ("sqlalchemy")), ( - sql.read_sql_query, - "SELECT * FROM types", - ("sqlalchemy", 
"fallback"), + sql.read_sql_query, + "SELECT * FROM types", + ("sqlalchemy", "fallback"), ), (sql.read_sql_table, "types", ("sqlalchemy")), ], ) def test_api_custom_dateparsing_error( - conn, request, read_sql, text, mode, error, types_data_frame + conn, request, read_sql, text, mode, error, types_data_frame ): conn_name = conn conn = request.getfixturevalue(conn) @@ -2389,12 +2391,12 @@ def test_warning_case_insensitive_table_name(conn, request, test_frame1): conn = request.getfixturevalue(conn) # see gh-7815 with tm.assert_produces_warning( - UserWarning, - match=( - r"The provided table name 'TABLE1' is not found exactly as such in " - r"the database after writing the table, possibly due to case " - r"sensitivity issues. Consider using lower case table names." - ), + UserWarning, + match=( + r"The provided table name 'TABLE1' is not found exactly as such in " + r"the database after writing the table, possibly due to case " + r"sensitivity issues. Consider using lower case table names." + ), ): with sql.SQLDatabase(conn) as db: db.check_case_sensitive("TABLE1", "") @@ -2459,7 +2461,7 @@ def test_sqlalchemy_integer_overload_mapping(conn, request, integer): df = DataFrame([0, 1], columns=["a"], dtype=integer) with sql.SQLDatabase(conn) as db: with pytest.raises( - ValueError, match="Unsigned 64 bit integer datatype is not supported" + ValueError, match="Unsigned 64 bit integer datatype is not supported" ): sql.SQLTable("test_type", db, frame=df) @@ -2750,7 +2752,7 @@ def test_sqlalchemy_default_type_conversion(conn, request): def test_bigint(conn, request): # int64 should be converted to BigInteger, GH7433 conn = request.getfixturevalue(conn) - df = DataFrame(data={"i64": [2 ** 62]}) + df = DataFrame(data={"i64": [2**62]}) assert df.to_sql(name="test_bigint", con=conn, index=False) == 1 result = sql.read_sql_table("test_bigint", conn) @@ -2989,7 +2991,7 @@ def test_datetime_time(conn, request, sqlite_buildin): def test_mixed_dtype_insert(conn, request): # see GH6509 conn = request.getfixturevalue(conn) - s1 = Series(2 ** 25 + 1, dtype=np.int32) + s1 = Series(2**25 + 1, dtype=np.int32) s2 = Series(0.0, dtype=np.float32) df = DataFrame({"s1": s1, "s2": s2}) @@ -3287,14 +3289,14 @@ def test_double_precision(conn, request): ) assert ( - df.to_sql( - name="test_dtypes", - con=conn, - index=False, - if_exists="replace", - dtype={"f64_as_f32": Float(precision=23)}, - ) - == 1 + df.to_sql( + name="test_dtypes", + con=conn, + index=False, + if_exists="replace", + dtype={"f64_as_f32": Float(precision=23)}, + ) + == 1 ) res = sql.read_sql_table("test_dtypes", conn) @@ -3342,8 +3344,8 @@ def main(connectable): test_connectable(connectable) assert ( - DataFrame({"test_foo_data": [0, 1, 2]}).to_sql(name="test_foo_data", con=conn) - == 3 + DataFrame({"test_foo_data": [0, 1, 2]}).to_sql(name="test_foo_data", con=conn) + == 3 ) main(conn) @@ -3502,13 +3504,13 @@ def test_get_engine_auto_error_message(): @pytest.mark.parametrize("conn", all_connectable) @pytest.mark.parametrize("func", ["read_sql", "read_sql_query"]) def test_read_sql_dtype_backend( - conn, - request, - string_storage, - func, - dtype_backend, - dtype_backend_data, - dtype_backend_expected, + conn, + request, + string_storage, + func, + dtype_backend, + dtype_backend_data, + dtype_backend_expected, ): # GH#50048 conn_name = conn @@ -3545,13 +3547,13 @@ def test_read_sql_dtype_backend( @pytest.mark.parametrize("conn", all_connectable) @pytest.mark.parametrize("func", ["read_sql", "read_sql_table"]) def test_read_sql_dtype_backend_table( - 
conn, - request, - string_storage, - func, - dtype_backend, - dtype_backend_data, - dtype_backend_expected, + conn, + request, + string_storage, + func, + dtype_backend, + dtype_backend_data, + dtype_backend_expected, ): if "sqlite" in conn and "adbc" not in conn: request.applymarker( @@ -3692,10 +3694,10 @@ def test_chunksize_empty_dtypes(conn, request): df.to_sql(name="test", con=conn, index=False, if_exists="replace") for result in read_sql_query( - "SELECT * FROM test", - conn, - dtype=dtypes, - chunksize=1, + "SELECT * FROM test", + conn, + dtype=dtypes, + chunksize=1, ): tm.assert_frame_equal(result, expected) @@ -3774,8 +3776,8 @@ class Test(BaseModel): with Session() as session: df = DataFrame({"id": [0, 1], "string_column": ["hello", "world"]}) assert ( - df.to_sql(name="test_frame", con=conn, index=False, if_exists="replace") - == 2 + df.to_sql(name="test_frame", con=conn, index=False, if_exists="replace") + == 2 ) session.commit() test_query = session.query(Test.id, Test.string_column) @@ -3815,7 +3817,7 @@ def test_roundtripping_datetimes(sqlite_engine): @pytest.fixture def sqlite_builtin_detect_types(): with contextlib.closing( - sqlite3.connect(":memory:", detect_types=sqlite3.PARSE_DECLTYPES) + sqlite3.connect(":memory:", detect_types=sqlite3.PARSE_DECLTYPES) ) as closing_conn: with closing_conn as conn: yield conn @@ -3847,16 +3849,16 @@ def test_psycopg2_schema_support(postgresql_psycopg2_engine): # write dataframe to different schema's assert df.to_sql(name="test_schema_public", con=conn, index=False) == 2 assert ( - df.to_sql( - name="test_schema_public_explicit", - con=conn, - index=False, - schema="public", - ) - == 2 + df.to_sql( + name="test_schema_public_explicit", + con=conn, + index=False, + schema="public", + ) + == 2 ) assert ( - df.to_sql(name="test_schema_other", con=conn, index=False, schema="other") == 2 + df.to_sql(name="test_schema_other", con=conn, index=False, schema="other") == 2 ) # read dataframes back in @@ -3882,7 +3884,7 @@ def test_psycopg2_schema_support(postgresql_psycopg2_engine): # write dataframe with different if_exists options assert ( - df.to_sql(name="test_schema_other", con=conn, schema="other", index=False) == 2 + df.to_sql(name="test_schema_other", con=conn, schema="other", index=False) == 2 ) df.to_sql( name="test_schema_other", @@ -3892,14 +3894,14 @@ def test_psycopg2_schema_support(postgresql_psycopg2_engine): if_exists="replace", ) assert ( - df.to_sql( - name="test_schema_other", - con=conn, - schema="other", - index=False, - if_exists="append", - ) - == 2 + df.to_sql( + name="test_schema_other", + con=conn, + schema="other", + index=False, + if_exists="append", + ) + == 2 ) res = sql.read_sql_table("test_schema_other", conn, schema="other") tm.assert_frame_equal(concat([df, df], ignore_index=True), res) @@ -4043,18 +4045,18 @@ def test_sqlite_illegal_names(sqlite_buildin): df.to_sql(name="", con=conn) for ndx, weird_name in enumerate( - [ - "test_weird_name]", - "test_weird_name[", - "test_weird_name`", - 'test_weird_name"', - "test_weird_name'", - "_b.test_weird_name_01-30", - '"_b.test_weird_name_01-30"', - "99beginswithnumber", - "12345", - "\xe9", - ] + [ + "test_weird_name]", + "test_weird_name[", + "test_weird_name`", + 'test_weird_name"', + "test_weird_name'", + "_b.test_weird_name_01-30", + '"_b.test_weird_name_01-30"', + "99beginswithnumber", + "12345", + "\xe9", + ] ): assert df.to_sql(name=weird_name, con=conn) == 2 sql.table_exists(weird_name, conn) @@ -4288,39 +4290,39 @@ def test_xsqlite_if_exists(sqlite_buildin): 
) assert tquery(sql_select, con=sqlite_buildin) == [(1, "A"), (2, "B")] assert ( - sql.to_sql( - frame=df_if_exists_2, - con=sqlite_buildin, - name=table_name, - if_exists="replace", - index=False, - ) - == 3 + sql.to_sql( + frame=df_if_exists_2, + con=sqlite_buildin, + name=table_name, + if_exists="replace", + index=False, + ) + == 3 ) assert tquery(sql_select, con=sqlite_buildin) == [(3, "C"), (4, "D"), (5, "E")] drop_table(table_name, sqlite_buildin) # test if_exists='append' assert ( - sql.to_sql( - frame=df_if_exists_1, - con=sqlite_buildin, - name=table_name, - if_exists="fail", - index=False, - ) - == 2 + sql.to_sql( + frame=df_if_exists_1, + con=sqlite_buildin, + name=table_name, + if_exists="fail", + index=False, + ) + == 2 ) assert tquery(sql_select, con=sqlite_buildin) == [(1, "A"), (2, "B")] assert ( - sql.to_sql( - frame=df_if_exists_2, - con=sqlite_buildin, - name=table_name, - if_exists="append", - index=False, - ) - == 3 + sql.to_sql( + frame=df_if_exists_2, + con=sqlite_buildin, + name=table_name, + if_exists="append", + index=False, + ) + == 3 ) assert tquery(sql_select, con=sqlite_buildin) == [ (1, "A"), @@ -4361,4 +4363,4 @@ def pandas_insert_patched(self, *args, **kwargs): matches = re.findall(pattern, sql_statement) - assert len([a for a in matches if a.startswith("(A_")]) > 1 + assert len([a for a in matches if a.startswith("(A_")]) > 1 \ No newline at end of file From 8b6fd74734cb3cd99cf23f72a3ccfecd48a5aeeb Mon Sep 17 00:00:00 2001 From: Samuel Chai Date: Mon, 12 Feb 2024 20:57:09 -0500 Subject: [PATCH 07/11] Simplifying unit test --- pandas/tests/io/test_sql.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 4fe330288f46f..da24ee346275c 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -4335,32 +4335,37 @@ def test_xsqlite_if_exists(sqlite_buildin): def test_execution_of_multi(mysql_pymysql_engine): - - from sqlalchemy import event from pandas.io.sql import SQLTable + from sqlalchemy import event original_function = SQLTable._execute_insert_multi - frame = DataFrame( - np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object)) + frame = DataFrame(np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object)) - statements = [] + # Track whether execute_many is True for the statements + # Multi-value inserts will be a single statement and therefore + # ``execute_many`` will be False + execute_many_types = [] - def track_statements(_, __, statement, ___, ____, _____): - nonlocal statements - statements.append(statement) + def track_statements(_, __, ___, ____, _____, execute_many): + nonlocal execute_many_types + execute_many_types.append(execute_many) + # A connection is the first argument passed to the _execute_insert_multi + # function. 
Add an event listener to this connection def pandas_insert_patched(self, *args, **kwargs): event.listen(args[0], "before_cursor_execute", track_statements) return original_function(self, *args, **kwargs) + # Patch this function onto the insert function to capture the event listener pd.io.sql.SQLTable._execute_insert_multi = pandas_insert_patched - frame.to_sql("test_multi_prepared_statement", mysql_pymysql_engine, method="multi", index=False) - sql_statement = statements[-1] - - pattern = r'\([^()]+\)' - - matches = re.findall(pattern, sql_statement) + frame.to_sql("test_multi_prepared_statement", + mysql_pymysql_engine, + if_exists="append", + method="multi", + index=False) - assert len([a for a in matches if a.startswith("(A_")]) > 1 \ No newline at end of file + # The last statement executed will be the insert statement + # Ensure that this is not using ``execute_many`` + assert not execute_many_types[-1] From cf8be43bebde3fc939d4257862cf6efe3a8868df Mon Sep 17 00:00:00 2001 From: Samuel Chai Date: Tue, 13 Feb 2024 14:56:10 -0500 Subject: [PATCH 08/11] Removing unit test --- pandas/tests/io/test_sql.py | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index da24ee346275c..49bf3e705aef9 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -4332,40 +4332,3 @@ def test_xsqlite_if_exists(sqlite_buildin): (5, "E"), ] drop_table(table_name, sqlite_buildin) - - -def test_execution_of_multi(mysql_pymysql_engine): - from pandas.io.sql import SQLTable - from sqlalchemy import event - original_function = SQLTable._execute_insert_multi - - frame = DataFrame(np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object)) - - # Track whether execute_many is True for the statements - # Multi-value inserts will be a single statement and therefore - # ``execute_many`` will be False - execute_many_types = [] - - def track_statements(_, __, ___, ____, _____, execute_many): - nonlocal execute_many_types - execute_many_types.append(execute_many) - - # A connection is the first argument passed to the _execute_insert_multi - # function. Add an event listener to this connection - def pandas_insert_patched(self, *args, **kwargs): - event.listen(args[0], "before_cursor_execute", track_statements) - return original_function(self, *args, **kwargs) - - # Patch this function onto the insert function to capture the event listener - pd.io.sql.SQLTable._execute_insert_multi = pandas_insert_patched - - frame.to_sql("test_multi_prepared_statement", - mysql_pymysql_engine, - if_exists="append", - method="multi", - index=False) - - # The last statement executed will be the insert statement - # Ensure that this is not using ``execute_many`` - assert not execute_many_types[-1] From a7a1cc55c51ef191c4b11c6952de1587326a0529 Mon Sep 17 00:00:00 2001 From: Samuel Chai Date: Tue, 13 Feb 2024 19:41:39 -0500 Subject: [PATCH 09/11] remove trailing whitespaces --- pandas/core/generic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 4ec586eb16153..72d970a31c680 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2913,8 +2913,8 @@ def to_sql( database. Otherwise, the datetimes will be stored as timezone unaware timestamps local to the original timezone. - Not all datastores support ``method="multi"``. 
Oracle, for example,
-        does not support multi-value insert
+        Not all datastores support ``method="multi"``. Oracle, for example,
+        does not support multi-value insert
 
         References
         ----------

From eefe5c3c062f1024f0051f25b2602d1cd92ea05a Mon Sep 17 00:00:00 2001
From: Samuel Chai
Date: Wed, 14 Feb 2024 00:07:13 -0500
Subject: [PATCH 10/11] Removing trailing whitespace

---
 pandas/core/generic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pandas/core/generic.py b/pandas/core/generic.py
index 72d970a31c680..f21422dfe5b65 100644
--- a/pandas/core/generic.py
+++ b/pandas/core/generic.py
@@ -2914,7 +2914,7 @@ def to_sql(
         timestamps local to the original timezone.
 
         Not all datastores support ``method="multi"``. Oracle, for example,
-        does not support multi-value insert
+        does not support multi-value insert.
 
         References
         ----------

From 5462c7c9aa7898526bc633fabf0487abacdee1f1 Mon Sep 17 00:00:00 2001
From: Samuel Chai
Date: Fri, 16 Feb 2024 18:08:42 -0500
Subject: [PATCH 11/11] Fixing alphabetical sorting

---
 doc/source/whatsnew/v2.2.1.rst | 2 +-
 pandas/core/generic.py         | 2 +-
 pandas/tests/io/test_sql.py    | 1 -
 3 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/doc/source/whatsnew/v2.2.1.rst b/doc/source/whatsnew/v2.2.1.rst
index f52d4673148f0..e6f011ac14afd 100644
--- a/doc/source/whatsnew/v2.2.1.rst
+++ b/doc/source/whatsnew/v2.2.1.rst
@@ -21,12 +21,12 @@ Fixed regressions
 - Fixed regression in :meth:`DataFrame.sort_index` not producing a stable sort for a index with duplicates (:issue:`57151`)
 - Fixed regression in :meth:`DataFrame.to_dict` with ``orient='list'`` and datetime or timedelta types returning integers (:issue:`54824`)
 - Fixed regression in :meth:`DataFrame.to_json` converting nullable integers to floats (:issue:`57224`)
+- Fixed regression in :meth:`DataFrame.to_sql` when ``method="multi"`` is passed and the dialect type is not Oracle (:issue:`57310`)
 - Fixed regression in :meth:`DataFrameGroupBy.idxmin`, :meth:`DataFrameGroupBy.idxmax`, :meth:`SeriesGroupBy.idxmin`, :meth:`SeriesGroupBy.idxmax` ignoring the ``skipna`` argument (:issue:`57040`)
 - Fixed regression in :meth:`DataFrameGroupBy.idxmin`, :meth:`DataFrameGroupBy.idxmax`, :meth:`SeriesGroupBy.idxmin`, :meth:`SeriesGroupBy.idxmax` where values containing the minimum or maximum value for the dtype could produce incorrect results (:issue:`57040`)
 - Fixed regression in :meth:`Index.join` raising ``TypeError`` when joining an empty index to a non-empty index containing mixed dtype values (:issue:`57048`)
 - Fixed regression in :meth:`Series.pct_change` raising a ``ValueError`` for an empty :class:`Series` (:issue:`57056`)
 - Fixed regression in :meth:`Series.to_numpy` when dtype is given as float and the data contains NaNs (:issue:`57121`)
-- Fixed regression in :meth:`DataFrame.to_sql` when ``method="multi"`` is passed and the dialect type is not Oracle (:issue:`57310`)
 
 .. ---------------------------------------------------------------------------
 .. _whatsnew_221.bug_fixes:

diff --git a/pandas/core/generic.py b/pandas/core/generic.py
index f21422dfe5b65..d624877f42262 100644
--- a/pandas/core/generic.py
+++ b/pandas/core/generic.py
@@ -2912,7 +2912,7 @@ def to_sql(
         ``Timestamp with timezone`` type with SQLAlchemy if supported by the
         database. Otherwise, the datetimes will be stored as timezone unaware
         timestamps local to the original timezone.
-
+
         Not all datastores support ``method="multi"``. Oracle, for example,
         does not support multi-value insert.
diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py
index 49bf3e705aef9..8bb67fac19c65 100644
--- a/pandas/tests/io/test_sql.py
+++ b/pandas/tests/io/test_sql.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import re
 import contextlib
 from contextlib import closing
 import csv
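
The series thus converges on keeping only the multi-value form in
_execute_insert_multi and documenting the Oracle limitation in to_sql,
rather than branching on conn.dialect at runtime. For readers who want to
see the two insert styles the patches distinguish, here is a minimal
editorial sketch (not part of the patches themselves). It assumes
SQLAlchemy 2.x and an in-memory SQLite engine; the "demo" table, its
columns, and the engine URL are hypothetical names used only for
illustration.

import pandas as pd
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
metadata = sa.MetaData()
# Hypothetical table, standing in for the table to_sql would create.
demo = sa.Table(
    "demo",
    metadata,
    sa.Column("a", sa.Integer),
    sa.Column("b", sa.Integer),
)
metadata.create_all(engine)

rows = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]

with engine.begin() as conn:
    # executemany style (pandas' default _execute_insert): one prepared
    # INSERT executed once per parameter set; accepted by every dialect,
    # including Oracle.
    conn.execute(sa.insert(demo), rows)

    # Multi-value style (what method="multi" produces after this series):
    # .values(rows) renders a single INSERT ... VALUES (...), (...)
    # statement, which Oracle's dialect does not accept.
    conn.execute(sa.insert(demo).values(rows))

# The same multi-value path exercised through the public API:
df = pd.DataFrame(rows)
df.to_sql("demo", engine, if_exists="append", index=False, method="multi")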