From 3373c63dc9a3c880c4b9fe3e042ba6a40a79249a Mon Sep 17 00:00:00 2001 From: Fangchen Li Date: Sun, 16 Jan 2022 22:54:36 -0600 Subject: [PATCH 1/2] CLN: remove sqlalchemy<14 compat --- pandas/io/sql.py | 48 ++++------------- pandas/tests/io/test_sql.py | 103 ++++++++++-------------------------- 2 files changed, 38 insertions(+), 113 deletions(-) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index f7fdc47afa8d1..b723eea334e84 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -46,7 +46,6 @@ from pandas.core.base import PandasObject import pandas.core.common as com from pandas.core.tools.datetimes import to_datetime -from pandas.util.version import Version class DatabaseError(OSError): @@ -57,16 +56,6 @@ class DatabaseError(OSError): # -- Helper functions -def _gt14() -> bool: - """ - Check if sqlalchemy.__version__ is at least 1.4.0, when several - deprecations were made. - """ - import sqlalchemy - - return Version(sqlalchemy.__version__) >= Version("1.4.0") - - def _convert_params(sql, params): """Convert SQL and params args to DBAPI2.0 compliant format.""" args = [sql] @@ -814,10 +803,7 @@ def sql_schema(self): def _execute_create(self): # Inserting table into database, add to MetaData object - if _gt14(): - self.table = self.table.to_metadata(self.pd_sql.meta) - else: - self.table = self.table.tometadata(self.pd_sql.meta) + self.table = self.table.to_metadata(self.pd_sql.meta) self.table.create(bind=self.pd_sql.connectable) def create(self): @@ -986,10 +972,9 @@ def read(self, coerce_float=True, parse_dates=None, columns=None, chunksize=None if self.index is not None: for idx in self.index[::-1]: cols.insert(0, self.table.c[idx]) - sql_select = select(*cols) if _gt14() else select(cols) + sql_select = select(*cols) else: - sql_select = select(self.table) if _gt14() else self.table.select() - + sql_select = select(self.table) result = self.pd_sql.execute(sql_select) column_names = result.keys() @@ -1633,19 +1618,11 @@ def check_case_sensitive( if not name.isdigit() and not name.islower(): # check for potentially case sensitivity issues (GH7815) # Only check when name is not a number and name is not lower case - engine = self.connectable.engine - with self.connectable.connect() as conn: - if _gt14(): - from sqlalchemy import inspect + from sqlalchemy import inspect - insp = inspect(conn) - table_names = insp.get_table_names( - schema=schema or self.meta.schema - ) - else: - table_names = engine.table_names( - schema=schema or self.meta.schema, connection=conn - ) + with self.connectable.connect() as conn: + insp = inspect(conn) + table_names = insp.get_table_names(schema=schema or self.meta.schema) if name not in table_names: msg = ( f"The provided table name '{name}' is not found exactly as " @@ -1749,15 +1726,10 @@ def tables(self): return self.meta.tables def has_table(self, name: str, schema: str | None = None): - if _gt14(): - from sqlalchemy import inspect + from sqlalchemy import inspect - insp = inspect(self.connectable) - return insp.has_table(name, schema or self.meta.schema) - else: - return self.connectable.run_callable( - self.connectable.dialect.has_table, name, schema or self.meta.schema - ) + insp = inspect(self.connectable) + return insp.has_table(name, schema or self.meta.schema) def get_table(self, table_name: str, schema: str | None = None): from sqlalchemy import ( diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 7a94797543519..cdb26bc9d1825 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -31,8 +31,6 @@ import numpy as np import pytest -import pandas.util._test_decorators as td - from pandas.core.dtypes.common import ( is_datetime64_dtype, is_datetime64tz_dtype, @@ -58,7 +56,6 @@ SQLAlchemyEngine, SQLDatabase, SQLiteDatabase, - _gt14, get_engine, pandasSQL_builder, read_sql_query, @@ -385,10 +382,10 @@ def mysql_pymysql_engine(iris_path, types_data): "mysql+pymysql://root@localhost:3306/pandas", connect_args={"client_flag": pymysql.constants.CLIENT.MULTI_STATEMENTS}, ) - check_target = sqlalchemy.inspect(engine) if _gt14() else engine - if not check_target.has_table("iris"): + insp = sqlalchemy.inspect(engine) + if not insp.has_table("iris"): create_and_load_iris(engine, iris_path, "mysql") - if not check_target.has_table("types"): + if not insp.has_table("types"): for entry in types_data: entry.pop("DateColWithTz") create_and_load_types(engine, types_data, "mysql") @@ -412,10 +409,10 @@ def postgresql_psycopg2_engine(iris_path, types_data): engine = sqlalchemy.create_engine( "postgresql+psycopg2://postgres:postgres@localhost:5432/pandas" ) - check_target = sqlalchemy.inspect(engine) if _gt14() else engine - if not check_target.has_table("iris"): + insp = sqlalchemy.inspect(engine) + if not insp.has_table("iris"): create_and_load_iris(engine, iris_path, "postgresql") - if not check_target.has_table("types"): + if not insp.has_table("types"): create_and_load_types(engine, types_data, "postgresql") yield engine with engine.connect() as conn: @@ -1425,14 +1422,6 @@ def test_database_uri_string(self, test_frame1): tm.assert_frame_equal(test_frame1, test_frame3) tm.assert_frame_equal(test_frame1, test_frame4) - @td.skip_if_installed("pg8000") - def test_pg8000_sqlalchemy_passthrough_error(self): - # using driver that will not be installed on CI to trigger error - # in sqlalchemy.create_engine -> test passing of this error to user - db_uri = "postgresql+pg8000://user:pass@host/dbname" - with pytest.raises(ImportError, match="pg8000"): - sql.read_sql("select * from table", db_uri) - def test_query_by_text_obj(self): # WIP : GH10846 from sqlalchemy import text @@ -1450,8 +1439,7 @@ def test_query_by_select_obj(self): ) iris = iris_table_metadata(self.flavor) - iris_select = iris if _gt14() else [iris] - name_select = select(iris_select).where(iris.c.Name == bindparam("name")) + name_select = select(iris).where(iris.c.Name == bindparam("name")) iris_df = sql.read_sql(name_select, self.conn, params={"name": "Iris-setosa"}) all_names = set(iris_df["Name"]) assert all_names == {"Iris-setosa"} @@ -1624,46 +1612,33 @@ def test_to_sql_empty(self, test_frame1): self._to_sql_empty(test_frame1) def test_create_table(self): + from sqlalchemy import inspect + temp_conn = self.connect() temp_frame = DataFrame( {"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]} ) - pandasSQL = sql.SQLDatabase(temp_conn) assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4 - if _gt14(): - from sqlalchemy import inspect - - insp = inspect(temp_conn) - assert insp.has_table("temp_frame") - else: - assert temp_conn.has_table("temp_frame") + insp = inspect(temp_conn) + assert insp.has_table("temp_frame") def test_drop_table(self): - temp_conn = self.connect() + from sqlalchemy import inspect + temp_conn = self.connect() temp_frame = DataFrame( {"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]} ) - pandasSQL = sql.SQLDatabase(temp_conn) assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4 - if _gt14(): - from sqlalchemy import inspect - - insp = inspect(temp_conn) - assert insp.has_table("temp_frame") - else: - assert temp_conn.has_table("temp_frame") + insp = inspect(temp_conn) + assert insp.has_table("temp_frame") pandasSQL.drop_table("temp_frame") - - if _gt14(): - assert not insp.has_table("temp_frame") - else: - assert not temp_conn.has_table("temp_frame") + assert not insp.has_table("temp_frame") def test_roundtrip(self, test_frame1): self._roundtrip(test_frame1) @@ -2156,14 +2131,10 @@ def bar(connection, data): data.to_sql(name="test_foo_data", con=connection, if_exists="append") def baz(conn): - if _gt14(): - # https://github.com/sqlalchemy/sqlalchemy/commit/ - # 00b5c10846e800304caa86549ab9da373b42fa5d#r48323973 - foo_data = foo(conn) - bar(conn, foo_data) - else: - foo_data = conn.run_callable(foo) - conn.run_callable(bar, foo_data) + # https://github.com/sqlalchemy/sqlalchemy/commit/ + # 00b5c10846e800304caa86549ab9da373b42fa5d#r48323973 + foo_data = foo(conn) + bar(conn, foo_data) def main(connectable): if isinstance(connectable, Engine): @@ -2216,14 +2187,9 @@ def test_temporary_table(self): ) from sqlalchemy.orm import ( Session, - sessionmaker, + declarative_base, ) - if _gt14(): - from sqlalchemy.orm import declarative_base - else: - from sqlalchemy.ext.declarative import declarative_base - test_data = "Hello, World!" expected = DataFrame({"spam": [test_data]}) Base = declarative_base() @@ -2234,24 +2200,13 @@ class Temporary(Base): id = Column(Integer, primary_key=True) spam = Column(Unicode(30), nullable=False) - if _gt14(): - with Session(self.conn) as session: - with session.begin(): - conn = session.connection() - Temporary.__table__.create(conn) - session.add(Temporary(spam=test_data)) - session.flush() - df = sql.read_sql_query(sql=select(Temporary.spam), con=conn) - else: - Session = sessionmaker() - session = Session(bind=self.conn) - with session.transaction: + with Session(self.conn) as session: + with session.begin(): conn = session.connection() Temporary.__table__.create(conn) session.add(Temporary(spam=test_data)) session.flush() - df = sql.read_sql_query(sql=select([Temporary.spam]), con=conn) - + df = sql.read_sql_query(sql=select(Temporary.spam), con=conn) tm.assert_frame_equal(df, expected) # -- SQL Engine tests (in the base class for now) @@ -2349,12 +2304,10 @@ def test_row_object_is_named_tuple(self): Integer, String, ) - from sqlalchemy.orm import sessionmaker - - if _gt14(): - from sqlalchemy.orm import declarative_base - else: - from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.orm import ( + declarative_base, + sessionmaker, + ) BaseModel = declarative_base() From ea90c45ea3d120297cce26b43370442f928e83cf Mon Sep 17 00:00:00 2001 From: Fangchen Li Date: Mon, 17 Jan 2022 11:37:44 -0600 Subject: [PATCH 2/2] revert change --- pandas/tests/io/test_sql.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index cdb26bc9d1825..741af4324c1a6 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -31,6 +31,8 @@ import numpy as np import pytest +import pandas.util._test_decorators as td + from pandas.core.dtypes.common import ( is_datetime64_dtype, is_datetime64tz_dtype, @@ -1422,6 +1424,14 @@ def test_database_uri_string(self, test_frame1): tm.assert_frame_equal(test_frame1, test_frame3) tm.assert_frame_equal(test_frame1, test_frame4) + @td.skip_if_installed("pg8000") + def test_pg8000_sqlalchemy_passthrough_error(self): + # using driver that will not be installed on CI to trigger error + # in sqlalchemy.create_engine -> test passing of this error to user + db_uri = "postgresql+pg8000://user:pass@host/dbname" + with pytest.raises(ImportError, match="pg8000"): + sql.read_sql("select * from table", db_uri) + def test_query_by_text_obj(self): # WIP : GH10846 from sqlalchemy import text