Skip to content

Commit a1bd004

Browse files
BUG/TST: fix get_schema + add tests
- adapt signature of get_schema to what it was before (for backwards compatibility), possibility to call function without connection object - fix get_schema for sqlalchemy mode - add tests
1 parent 7aad61f commit a1bd004

File tree

2 files changed

+72
-10
lines changed

2 files changed

+72
-10
lines changed

pandas/io/sql.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,8 @@ def exists(self):
484484
return self.pd_sql.has_table(self.name)
485485

486486
def sql_schema(self):
487-
return str(self.table.compile())
487+
from sqlalchemy.schema import CreateTable
488+
return str(CreateTable(self.table))
488489

489490
def create(self):
490491
self.table.create()
@@ -782,7 +783,7 @@ def drop_table(self, table_name):
782783

783784
def _create_sql_schema(self, frame, table_name):
784785
table = PandasSQLTable(table_name, self, frame=frame)
785-
return str(table.compile())
786+
return str(table.sql_schema())
786787

787788

788789
# ---- SQL without SQLAlchemy ---
@@ -1028,27 +1029,78 @@ def drop_table(self, name):
10281029
drop_sql = "DROP TABLE %s" % name
10291030
self.execute(drop_sql)
10301031

1032+
def _create_sql_schema(self, frame, table_name):
1033+
table = PandasSQLTableLegacy(table_name, self, frame=frame)
1034+
return str(table.sql_schema())
10311035

1032-
# legacy names, with depreciation warnings and copied docs
1033-
def get_schema(frame, name, con, flavor='sqlite'):
1036+
1037+
def get_schema(frame, name, flavor='sqlite', keys=None, con=None):
10341038
"""
10351039
Get the SQL db table schema for the given frame
10361040
10371041
Parameters
10381042
----------
1039-
frame: DataFrame
1040-
name: name of SQL table
1041-
con: an open SQL database connection object
1042-
engine: an SQLAlchemy engine - replaces connection and flavor
1043-
flavor: {'sqlite', 'mysql'}, default 'sqlite'
1043+
frame : DataFrame
1044+
name : name of SQL table
1045+
flavor : {'sqlite', 'mysql'}, default 'sqlite'
1046+
keys : columns to use a primary key
1047+
con: an open SQL database connection object or an SQLAlchemy engine
10441048
10451049
"""
1046-
warnings.warn("get_schema is depreciated", FutureWarning)
1050+
1051+
if con is None:
1052+
return _get_schema_legacy(frame, name, flavor, keys)
10471053

10481054
pandas_sql = pandasSQL_builder(con=con, flavor=flavor)
10491055
return pandas_sql._create_sql_schema(frame, name)
10501056

10511057

1058+
def _get_schema_legacy(frame, name, flavor, keys=None):
1059+
"""Old function from 0.13.1. To keep backwards compatibility.
1060+
When mysql legacy support is dropped, it should be possible to
1061+
remove this code
1062+
"""
1063+
1064+
def get_sqltype(dtype, flavor):
1065+
pytype = dtype.type
1066+
pytype_name = "text"
1067+
if issubclass(pytype, np.floating):
1068+
pytype_name = "float"
1069+
elif issubclass(pytype, np.integer):
1070+
pytype_name = "int"
1071+
elif issubclass(pytype, np.datetime64) or pytype is datetime:
1072+
# Caution: np.datetime64 is also a subclass of np.number.
1073+
pytype_name = "datetime"
1074+
elif pytype is datetime.date:
1075+
pytype_name = "date"
1076+
elif issubclass(pytype, np.bool_):
1077+
pytype_name = "bool"
1078+
1079+
return _SQL_TYPES[pytype_name][flavor]
1080+
1081+
lookup_type = lambda dtype: get_sqltype(dtype, flavor)
1082+
1083+
column_types = lzip(frame.dtypes.index, map(lookup_type, frame.dtypes))
1084+
if flavor == 'sqlite':
1085+
columns = ',\n '.join('[%s] %s' % x for x in column_types)
1086+
else:
1087+
columns = ',\n '.join('`%s` %s' % x for x in column_types)
1088+
1089+
keystr = ''
1090+
if keys is not None:
1091+
if isinstance(keys, string_types):
1092+
keys = (keys,)
1093+
keystr = ', PRIMARY KEY (%s)' % ','.join(keys)
1094+
template = """CREATE TABLE %(name)s (
1095+
%(columns)s
1096+
%(keystr)s
1097+
);"""
1098+
create_statement = template % {'name': name, 'columns': columns,
1099+
'keystr': keystr}
1100+
return create_statement
1101+
1102+
1103+
# legacy names, with depreciation warnings and copied docs
10521104

10531105
def read_frame(*args, **kwargs):
10541106
"""DEPRECIATED - use read_sql

pandas/io/tests/test_sql.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,11 @@ def test_integer_col_names(self):
565565
sql.to_sql(df, "test_frame_integer_col_names", self.conn,
566566
if_exists='replace')
567567

568+
def test_get_schema(self):
569+
create_sql = sql.get_schema(self.test_frame1, 'test', 'sqlite',
570+
con=self.conn)
571+
self.assert_('CREATE' in create_sql)
572+
568573

569574
class TestSQLApi(_TestSQLApi):
570575
"""
@@ -684,6 +689,11 @@ def test_safe_names_warning(self):
684689
sql.to_sql(df, "test_frame3_legacy", self.conn,
685690
flavor="sqlite", index=False)
686691

692+
def test_get_schema2(self):
693+
# without providing a connection object (available for backwards comp)
694+
create_sql = sql.get_schema(self.test_frame1, 'test', 'sqlite')
695+
self.assert_('CREATE' in create_sql)
696+
687697

688698
#------------------------------------------------------------------------------
689699
#--- Database flavor specific tests

0 commit comments

Comments
 (0)