From 648240ba64dae6b8e59778ec48739501ffda7add Mon Sep 17 00:00:00 2001 From: wakabame Date: Thu, 8 Sep 2022 21:55:37 +0900 Subject: [PATCH 1/3] Add sqlalchemy --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index d7bea32fb..0aac884b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ odfpy = ">=1.4.1" xarray = ">=22.6.0" tabulate = ">=0.8.10" scipy = ">=1.9.1" +SQLAlchemy = "^1.4.41" [build-system] From 02ed6920f9bbd3a07a414e0a2be69ac064d6c4b5 Mon Sep 17 00:00:00 2001 From: wakabame Date: Thu, 8 Sep 2022 14:18:37 +0900 Subject: [PATCH 2/3] Add testcases using sqlalchemy engine and connection --- tests/test_io.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_io.py b/tests/test_io.py index 0e0f47b47..d09a9bff8 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -43,6 +43,7 @@ ) from pandas._testing import ensure_clean import pytest +import sqlalchemy from typing_extensions import assert_type from tests import check @@ -692,6 +693,31 @@ def test_read_sql(): con.close() +def test_read_sql_via_sqlalchemy_connection(): + with ensure_clean() as path: + db_uri = "sqlite:///" + path + engine = sqlalchemy.create_engine(db_uri) + + with engine.connect() as conn: + check(assert_type(DF.to_sql("test", con=conn), Union[int, None]), int) + check( + assert_type(read_sql("select * from test", con=conn), DataFrame), + DataFrame, + ) + + +def test_read_sql_via_sqlalchemy_engine(): + with ensure_clean() as path: + db_uri = "sqlite:///" + path + engine = sqlalchemy.create_engine(db_uri) + + check(assert_type(DF.to_sql("test", con=engine), Union[int, None]), int) + check( + assert_type(read_sql("select * from test", con=engine), DataFrame), + DataFrame, + ) + + def test_read_sql_generator(): with ensure_clean() as path: con = sqlite3.connect(path) From 3f8d44d45b6ffe95678d99bc3a9554a5ce5ddb2b Mon Sep 17 00:00:00 2001 From: wakabame Date: Thu, 8 Sep 2022 14:09:50 +0900 Subject: [PATCH 3/3] Accept sqlalchemy.engine.Engine --- pandas-stubs/core/generic.pyi | 2 +- pandas-stubs/io/sql.pyi | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pandas-stubs/core/generic.pyi b/pandas-stubs/core/generic.pyi index 0fe58fc43..312219b6d 100644 --- a/pandas-stubs/core/generic.pyi +++ b/pandas-stubs/core/generic.pyi @@ -160,7 +160,7 @@ class NDFrame(PandasObject, indexing.IndexingMixin): def to_sql( self, name: _str, - con: str | sqlalchemy.engine.Connection | sqlite3.Connection, + con: str | sqlalchemy.engine.Connectable | sqlite3.Connection, schema: _str | None = ..., if_exists: Literal["fail", "replace", "append"] = ..., index: _bool = ..., diff --git a/pandas-stubs/io/sql.pyi b/pandas-stubs/io/sql.pyi index 0f40c21e2..055736b2c 100644 --- a/pandas-stubs/io/sql.pyi +++ b/pandas-stubs/io/sql.pyi @@ -23,7 +23,7 @@ class DatabaseError(IOError): ... @overload def read_sql_table( table_name: str, - con: str | sqlalchemy.engine.Connection | sqlite3.Connection, + con: str | sqlalchemy.engine.Connectable | sqlite3.Connection, schema: str | None = ..., index_col: str | list[str] | None = ..., coerce_float: bool = ..., @@ -35,7 +35,7 @@ def read_sql_table( @overload def read_sql_table( table_name: str, - con: str | sqlalchemy.engine.Connection | sqlite3.Connection, + con: str | sqlalchemy.engine.Connectable | sqlite3.Connection, schema: str | None = ..., index_col: str | list[str] | None = ..., coerce_float: bool = ..., @@ -46,7 +46,7 @@ def read_sql_table( @overload def read_sql_query( sql: str, - con: str | sqlalchemy.engine.Connection | sqlite3.Connection, + con: str | sqlalchemy.engine.Connectable | sqlite3.Connection, index_col: str | list[str] | None = ..., coerce_float: bool = ..., params: list[str] | tuple[str, ...] | dict[str, str] | None = ..., @@ -58,7 +58,7 @@ def read_sql_query( @overload def read_sql_query( sql: str, - con: str | sqlalchemy.engine.Connection | sqlite3.Connection, + con: str | sqlalchemy.engine.Connectable | sqlite3.Connection, index_col: str | list[str] | None = ..., coerce_float: bool = ..., params: list[str] | tuple[str, ...] | dict[str, str] | None = ..., @@ -69,7 +69,7 @@ def read_sql_query( @overload def read_sql( sql: str, - con: str | sqlalchemy.engine.Connection | sqlite3.Connection, + con: str | sqlalchemy.engine.Connectable | sqlite3.Connection, index_col: str | list[str] | None = ..., coerce_float: bool = ..., params: list[str] | tuple[str, ...] | dict[str, str] | None = ..., @@ -81,7 +81,7 @@ def read_sql( @overload def read_sql( sql: str, - con: str | sqlalchemy.engine.Connection | sqlite3.Connection, + con: str | sqlalchemy.engine.Connectable | sqlite3.Connection, index_col: str | list[str] | None = ..., coerce_float: bool = ..., params: list[str] | tuple[str, ...] | dict[str, str] | None = ...,