Skip to content

Commit 28fd624

Browse files
authored
Accept sqlalchemy.engine.Engine for SQL IO API (read_sql, to_sql) (#281)
* Add sqlalchemy * Add testcases using sqlalchemy engine and connection * Accept sqlalchemy.engine.Engine
1 parent 096b9b8 commit 28fd624

File tree

4 files changed

+34
-7
lines changed

4 files changed

+34
-7
lines changed

pandas-stubs/core/generic.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ class NDFrame(PandasObject, indexing.IndexingMixin):
160160
def to_sql(
161161
self,
162162
name: _str,
163-
con: str | sqlalchemy.engine.Connection | sqlite3.Connection,
163+
con: str | sqlalchemy.engine.Connectable | sqlite3.Connection,
164164
schema: _str | None = ...,
165165
if_exists: Literal["fail", "replace", "append"] = ...,
166166
index: _bool = ...,

pandas-stubs/io/sql.pyi

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class DatabaseError(IOError): ...
2323
@overload
2424
def read_sql_table(
2525
table_name: str,
26-
con: str | sqlalchemy.engine.Connection | sqlite3.Connection,
26+
con: str | sqlalchemy.engine.Connectable | sqlite3.Connection,
2727
schema: str | None = ...,
2828
index_col: str | list[str] | None = ...,
2929
coerce_float: bool = ...,
@@ -35,7 +35,7 @@ def read_sql_table(
3535
@overload
3636
def read_sql_table(
3737
table_name: str,
38-
con: str | sqlalchemy.engine.Connection | sqlite3.Connection,
38+
con: str | sqlalchemy.engine.Connectable | sqlite3.Connection,
3939
schema: str | None = ...,
4040
index_col: str | list[str] | None = ...,
4141
coerce_float: bool = ...,
@@ -46,7 +46,7 @@ def read_sql_table(
4646
@overload
4747
def read_sql_query(
4848
sql: str,
49-
con: str | sqlalchemy.engine.Connection | sqlite3.Connection,
49+
con: str | sqlalchemy.engine.Connectable | sqlite3.Connection,
5050
index_col: str | list[str] | None = ...,
5151
coerce_float: bool = ...,
5252
params: list[str] | tuple[str, ...] | dict[str, str] | None = ...,
@@ -58,7 +58,7 @@ def read_sql_query(
5858
@overload
5959
def read_sql_query(
6060
sql: str,
61-
con: str | sqlalchemy.engine.Connection | sqlite3.Connection,
61+
con: str | sqlalchemy.engine.Connectable | sqlite3.Connection,
6262
index_col: str | list[str] | None = ...,
6363
coerce_float: bool = ...,
6464
params: list[str] | tuple[str, ...] | dict[str, str] | None = ...,
@@ -69,7 +69,7 @@ def read_sql_query(
6969
@overload
7070
def read_sql(
7171
sql: str,
72-
con: str | sqlalchemy.engine.Connection | sqlite3.Connection,
72+
con: str | sqlalchemy.engine.Connectable | sqlite3.Connection,
7373
index_col: str | list[str] | None = ...,
7474
coerce_float: bool = ...,
7575
params: list[str] | tuple[str, ...] | dict[str, str] | None = ...,
@@ -81,7 +81,7 @@ def read_sql(
8181
@overload
8282
def read_sql(
8383
sql: str,
84-
con: str | sqlalchemy.engine.Connection | sqlite3.Connection,
84+
con: str | sqlalchemy.engine.Connectable | sqlite3.Connection,
8585
index_col: str | list[str] | None = ...,
8686
coerce_float: bool = ...,
8787
params: list[str] | tuple[str, ...] | dict[str, str] | None = ...,

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ odfpy = ">=1.4.1"
5757
xarray = ">=22.6.0"
5858
tabulate = ">=0.8.10"
5959
scipy = ">=1.9.1"
60+
SQLAlchemy = "^1.4.41"
6061

6162

6263
[build-system]

tests/test_io.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
)
4444
from pandas._testing import ensure_clean
4545
import pytest
46+
import sqlalchemy
4647
from typing_extensions import assert_type
4748

4849
from tests import check
@@ -692,6 +693,31 @@ def test_read_sql():
692693
con.close()
693694

694695

696+
def test_read_sql_via_sqlalchemy_connection():
697+
with ensure_clean() as path:
698+
db_uri = "sqlite:///" + path
699+
engine = sqlalchemy.create_engine(db_uri)
700+
701+
with engine.connect() as conn:
702+
check(assert_type(DF.to_sql("test", con=conn), Union[int, None]), int)
703+
check(
704+
assert_type(read_sql("select * from test", con=conn), DataFrame),
705+
DataFrame,
706+
)
707+
708+
709+
def test_read_sql_via_sqlalchemy_engine():
710+
with ensure_clean() as path:
711+
db_uri = "sqlite:///" + path
712+
engine = sqlalchemy.create_engine(db_uri)
713+
714+
check(assert_type(DF.to_sql("test", con=engine), Union[int, None]), int)
715+
check(
716+
assert_type(read_sql("select * from test", con=engine), DataFrame),
717+
DataFrame,
718+
)
719+
720+
695721
def test_read_sql_generator():
696722
with ensure_clean() as path:
697723
con = sqlite3.connect(path)

0 commit comments

Comments
 (0)