diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py index ba951c7cb513d..d95babff2653b 100644 --- a/pandas/io/tests/test_sql.py +++ b/pandas/io/tests/test_sql.py @@ -165,15 +165,64 @@ } +class MixInBase(object): + def tearDown(self): + for tbl in self._get_all_tables(): + self.drop_table(tbl) + self._close_conn() + + +class MySQLMixIn(MixInBase): + def drop_table(self, table_name): + cur = self.conn.cursor() + cur.execute("DROP TABLE IF EXISTS %s" % sql._get_valid_mysql_name(table_name)) + self.conn.commit() + + def _get_all_tables(self): + cur = self.conn.cursor() + cur.execute('SHOW TABLES') + return [table[0] for table in cur.fetchall()] + + def _close_conn(self): + from pymysql.err import Error + try: + self.conn.close() + except Error: + pass + + +class SQLiteMixIn(MixInBase): + def drop_table(self, table_name): + self.conn.execute("DROP TABLE IF EXISTS %s" % sql._get_valid_sqlite_name(table_name)) + self.conn.commit() + + def _get_all_tables(self): + c = self.conn.execute("SELECT name FROM sqlite_master WHERE type='table'") + return [table[0] for table in c.fetchall()] + + def _close_conn(self): + self.conn.close() + + +class SQLAlchemyMixIn(MixInBase): + def drop_table(self, table_name): + sql.SQLDatabase(self.conn).drop_table(table_name) + + def _get_all_tables(self): + meta = sqlalchemy.schema.MetaData(bind=self.conn) + meta.reflect() + table_list = meta.tables.keys() + return table_list + + def _close_conn(self): + pass + class PandasSQLTest(unittest.TestCase): """ Base class with common private methods for SQLAlchemy and fallback cases. """ - def drop_table(self, table_name): - self._get_exec().execute("DROP TABLE IF EXISTS %s" % table_name) - def _get_exec(self): if hasattr(self.conn, 'execute'): return self.conn @@ -768,7 +817,7 @@ def test_categorical(self): tm.assert_frame_equal(res, df) -class TestSQLApi(_TestSQLApi): +class TestSQLApi(SQLAlchemyMixIn, _TestSQLApi): """ Test the public API as it would be used directly @@ -889,13 +938,14 @@ def tearDown(self): self.conn.close() self.conn = self.__engine self.pandasSQL = sql.SQLDatabase(self.__engine) + super(_EngineToConnMixin, self).tearDown() class TestSQLApiConn(_EngineToConnMixin, TestSQLApi): pass -class TestSQLiteFallbackApi(_TestSQLApi): +class TestSQLiteFallbackApi(SQLiteMixIn, _TestSQLApi): """ Test the public sqlite connection fallback API @@ -978,7 +1028,7 @@ def test_sqlite_type_mapping(self): #--- Database flavor specific tests -class _TestSQLAlchemy(PandasSQLTest): +class _TestSQLAlchemy(SQLAlchemyMixIn, PandasSQLTest): """ Base class for testing the sqlalchemy backend. @@ -1451,10 +1501,6 @@ def setup_driver(cls): # sqlite3 is built-in cls.driver = None - def tearDown(self): - super(_TestSQLiteAlchemy, self).tearDown() - # in memory so tables should not be removed explicitly - def test_default_type_conversion(self): df = sql.read_sql_table("types_test_data", self.conn) @@ -1511,12 +1557,6 @@ def setup_driver(cls): except ImportError: raise nose.SkipTest('pymysql not installed') - def tearDown(self): - super(_TestMySQLAlchemy, self).tearDown() - c = self.conn.execute('SHOW TABLES') - for table in c.fetchall(): - self.conn.execute('DROP TABLE %s' % table[0]) - def test_default_type_conversion(self): df = sql.read_sql_table("types_test_data", self.conn) @@ -1586,14 +1626,6 @@ def setup_driver(cls): except ImportError: raise nose.SkipTest('psycopg2 not installed') - def tearDown(self): - super(_TestPostgreSQLAlchemy, self).tearDown() - c = self.conn.execute( - "SELECT table_name FROM information_schema.tables" - " WHERE table_schema = 'public'") - for table in c.fetchall(): - self.conn.execute("DROP TABLE %s" % table[0]) - def test_schema_support(self): # only test this for postgresql (schema's not supported in mysql/sqlite) df = DataFrame({'col1':[1, 2], 'col2':[0.1, 0.2], 'col3':['a', 'n']}) @@ -1694,7 +1726,7 @@ class TestSQLiteAlchemyConn(_TestSQLiteAlchemy, _TestSQLAlchemyConn): #------------------------------------------------------------------------------ #--- Test Sqlite / MySQL fallback -class TestSQLiteFallback(PandasSQLTest): +class TestSQLiteFallback(SQLiteMixIn, PandasSQLTest): """ Test the fallback mode against an in-memory sqlite database. @@ -1705,11 +1737,6 @@ class TestSQLiteFallback(PandasSQLTest): def connect(cls): return sqlite3.connect(':memory:') - def drop_table(self, table_name): - cur = self.conn.cursor() - cur.execute("DROP TABLE IF EXISTS %s" % table_name) - self.conn.commit() - def setUp(self): self.conn = self.connect() self.pandasSQL = sql.SQLiteDatabase(self.conn, 'sqlite') @@ -1856,7 +1883,7 @@ def test_illegal_names(self): for ndx, weird_name in enumerate(['test_weird_name]','test_weird_name[', 'test_weird_name`','test_weird_name"', 'test_weird_name\'', '_b.test_weird_name_01-30', '"_b.test_weird_name_01-30"', - '12345','12345blah']): + '99beginswithnumber', '12345']): df.to_sql(weird_name, self.conn, flavor=self.flavor) sql.table_exists(weird_name, self.conn) @@ -1866,7 +1893,7 @@ def test_illegal_names(self): sql.table_exists(c_tbl, self.conn) -class TestMySQLLegacy(TestSQLiteFallback): +class TestMySQLLegacy(MySQLMixIn, TestSQLiteFallback): """ Test the legacy mode against a MySQL database. @@ -1895,11 +1922,6 @@ def setup_driver(cls): def connect(cls): return cls.driver.connect(host='127.0.0.1', user='root', passwd='', db='pandas_nosetest') - def drop_table(self, table_name): - cur = self.conn.cursor() - cur.execute("DROP TABLE IF EXISTS %s" % table_name) - self.conn.commit() - def _count_rows(self, table_name): cur = self._get_exec() cur.execute( @@ -1918,14 +1940,6 @@ def setUp(self): self._load_iris_data() self._load_test1_data() - def tearDown(self): - c = self.conn.cursor() - c.execute('SHOW TABLES') - for table in c.fetchall(): - c.execute('DROP TABLE %s' % table[0]) - self.conn.commit() - self.conn.close() - def test_a_deprecation(self): with tm.assert_produces_warning(FutureWarning): sql.to_sql(self.test_frame1, 'test_frame1', self.conn, @@ -1963,14 +1977,10 @@ def test_illegal_names(self): for ndx, ok_name in enumerate(['99beginswithnumber','12345']): df.to_sql(ok_name, self.conn, flavor=self.flavor, index=False, if_exists='replace') - self.conn.cursor().execute("DROP TABLE `%s`" % ok_name) - self.conn.commit() df2 = DataFrame([[1, 2], [3, 4]], columns=['a', ok_name]) - c_tbl = 'test_ok_col_name%d'%ndx - df2.to_sql(c_tbl, self.conn, flavor=self.flavor, index=False, + + df2.to_sql('test_ok_col_name', self.conn, flavor=self.flavor, index=False, if_exists='replace') - self.conn.cursor().execute("DROP TABLE `%s`" % c_tbl) - self.conn.commit() # For MySQL, these should raise ValueError for ndx, illegal_name in enumerate(['test_illegal_name]','test_illegal_name[', @@ -1979,8 +1989,7 @@ def test_illegal_names(self): flavor=self.flavor, index=False) df2 = DataFrame([[1, 2], [3, 4]], columns=['a', illegal_name]) - c_tbl = 'test_illegal_col_name%d'%ndx - self.assertRaises(ValueError, df2.to_sql, c_tbl, + self.assertRaises(ValueError, df2.to_sql, 'test_illegal_col_name%d'%ndx, self.conn, flavor=self.flavor, index=False) @@ -2022,10 +2031,10 @@ def _skip_if_no_pymysql(): raise nose.SkipTest('pymysql not installed, skipping') -class TestXSQLite(tm.TestCase): +class TestXSQLite(SQLiteMixIn, tm.TestCase): def setUp(self): - self.db = sqlite3.connect(':memory:') + self.conn = sqlite3.connect(':memory:') def test_basic(self): frame = tm.makeTimeDataFrame() @@ -2036,34 +2045,34 @@ def test_write_row_by_row(self): frame = tm.makeTimeDataFrame() frame.ix[0, 0] = np.nan create_sql = sql.get_schema(frame, 'test', 'sqlite') - cur = self.db.cursor() + cur = self.conn.cursor() cur.execute(create_sql) - cur = self.db.cursor() + cur = self.conn.cursor() ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" for idx, row in frame.iterrows(): fmt_sql = format_query(ins, *row) sql.tquery(fmt_sql, cur=cur) - self.db.commit() + self.conn.commit() - result = sql.read_frame("select * from test", con=self.db) + result = sql.read_frame("select * from test", con=self.conn) result.index = frame.index tm.assert_frame_equal(result, frame) def test_execute(self): frame = tm.makeTimeDataFrame() create_sql = sql.get_schema(frame, 'test', 'sqlite') - cur = self.db.cursor() + cur = self.conn.cursor() cur.execute(create_sql) ins = "INSERT INTO test VALUES (?, ?, ?, ?)" row = frame.ix[0] - sql.execute(ins, self.db, params=tuple(row)) - self.db.commit() + sql.execute(ins, self.conn, params=tuple(row)) + self.conn.commit() - result = sql.read_frame("select * from test", self.db) + result = sql.read_frame("select * from test", self.conn) result.index = frame.index[:1] tm.assert_frame_equal(result, frame[:1]) @@ -2080,7 +2089,7 @@ def test_schema(self): create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],) lines = create_sql.splitlines() self.assertTrue('PRIMARY KEY ("A", "B")' in create_sql) - cur = self.db.cursor() + cur = self.conn.cursor() cur.execute(create_sql) def test_execute_fail(self): @@ -2093,17 +2102,17 @@ def test_execute_fail(self): PRIMARY KEY (a, b) ); """ - cur = self.db.cursor() + cur = self.conn.cursor() cur.execute(create_sql) - sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) - sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.db) + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.conn) + sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.conn) try: sys.stdout = StringIO() self.assertRaises(Exception, sql.execute, 'INSERT INTO test VALUES("foo", "bar", 7)', - self.db) + self.conn) finally: sys.stdout = sys.__stdout__ @@ -2117,24 +2126,27 @@ def test_execute_closed_connection(self): PRIMARY KEY (a, b) ); """ - cur = self.db.cursor() + cur = self.conn.cursor() cur.execute(create_sql) - sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) - self.db.close() + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.conn) + self.conn.close() try: sys.stdout = StringIO() self.assertRaises(Exception, sql.tquery, "select * from test", - con=self.db) + con=self.conn) finally: sys.stdout = sys.__stdout__ + # Initialize connection again (needed for tearDown) + self.setUp() + def test_na_roundtrip(self): pass def _check_roundtrip(self, frame): - sql.write_frame(frame, name='test_table', con=self.db) - result = sql.read_frame("select * from test_table", self.db) + sql.write_frame(frame, name='test_table', con=self.conn) + result = sql.read_frame("select * from test_table", self.conn) # HACK! Change this once indexes are handled properly. result.index = frame.index @@ -2145,8 +2157,8 @@ def _check_roundtrip(self, frame): frame['txt'] = ['a'] * len(frame) frame2 = frame.copy() frame2['Idx'] = Index(lrange(len(frame2))) + 10 - sql.write_frame(frame2, name='test_table2', con=self.db) - result = sql.read_frame("select * from test_table2", self.db, + sql.write_frame(frame2, name='test_table2', con=self.conn) + result = sql.read_frame("select * from test_table2", self.conn, index_col='Idx') expected = frame.copy() expected.index = Index(lrange(len(frame2))) + 10 @@ -2155,8 +2167,8 @@ def _check_roundtrip(self, frame): def test_tquery(self): frame = tm.makeTimeDataFrame() - sql.write_frame(frame, name='test_table', con=self.db) - result = sql.tquery("select A from test_table", self.db) + sql.write_frame(frame, name='test_table', con=self.conn) + result = sql.tquery("select A from test_table", self.conn) expected = Series(frame.A.values, frame.index) # not to have name result = Series(result, frame.index) tm.assert_series_equal(result, expected) @@ -2164,27 +2176,27 @@ def test_tquery(self): try: sys.stdout = StringIO() self.assertRaises(sql.DatabaseError, sql.tquery, - 'select * from blah', con=self.db) + 'select * from blah', con=self.conn) self.assertRaises(sql.DatabaseError, sql.tquery, - 'select * from blah', con=self.db, retry=True) + 'select * from blah', con=self.conn, retry=True) finally: sys.stdout = sys.__stdout__ def test_uquery(self): frame = tm.makeTimeDataFrame() - sql.write_frame(frame, name='test_table', con=self.db) + sql.write_frame(frame, name='test_table', con=self.conn) stmt = 'INSERT INTO test_table VALUES(2.314, -123.1, 1.234, 2.3)' - self.assertEqual(sql.uquery(stmt, con=self.db), 1) + self.assertEqual(sql.uquery(stmt, con=self.conn), 1) try: sys.stdout = StringIO() self.assertRaises(sql.DatabaseError, sql.tquery, - 'insert into blah values (1)', con=self.db) + 'insert into blah values (1)', con=self.conn) self.assertRaises(sql.DatabaseError, sql.tquery, - 'insert into blah values (1)', con=self.db, + 'insert into blah values (1)', con=self.conn, retry=True) finally: sys.stdout = sys.__stdout__ @@ -2193,16 +2205,16 @@ def test_keyword_as_column_names(self): ''' ''' df = DataFrame({'From':np.ones(5)}) - sql.write_frame(df, con = self.db, name = 'testkeywords') + sql.write_frame(df, con = self.conn, name = 'testkeywords') def test_onecolumn_of_integer(self): # GH 3628 # a column_of_integers dataframe should transfer well to sql mono_df=DataFrame([1 , 2], columns=['c0']) - sql.write_frame(mono_df, con = self.db, name = 'mono_df') + sql.write_frame(mono_df, con = self.conn, name = 'mono_df') # computing the sum via sql - con_x=self.db + con_x=self.conn the_sum=sum([my_c0[0] for my_c0 in con_x.execute("select * from mono_df")]) # it should not fail, and gives 3 ( Issue #3628 ) self.assertEqual(the_sum , 3) @@ -2221,56 +2233,53 @@ def clean_up(test_table_to_drop): Drops tables created from individual tests so no dependencies arise from sequential tests """ - if sql.table_exists(test_table_to_drop, self.db, flavor='sqlite'): - cur = self.db.cursor() - cur.execute("DROP TABLE %s" % test_table_to_drop) - cur.close() + self.drop_table(test_table_to_drop) # test if invalid value for if_exists raises appropriate error self.assertRaises(ValueError, sql.write_frame, frame=df_if_exists_1, - con=self.db, + con=self.conn, name=table_name, flavor='sqlite', if_exists='notvalidvalue') clean_up(table_name) # test if_exists='fail' - sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + sql.write_frame(frame=df_if_exists_1, con=self.conn, name=table_name, flavor='sqlite', if_exists='fail') self.assertRaises(ValueError, sql.write_frame, frame=df_if_exists_1, - con=self.db, + con=self.conn, name=table_name, flavor='sqlite', if_exists='fail') # test if_exists='replace' - sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + sql.write_frame(frame=df_if_exists_1, con=self.conn, name=table_name, flavor='sqlite', if_exists='replace') - self.assertEqual(sql.tquery(sql_select, con=self.db), + self.assertEqual(sql.tquery(sql_select, con=self.conn), [(1, 'A'), (2, 'B')]) - sql.write_frame(frame=df_if_exists_2, con=self.db, name=table_name, + sql.write_frame(frame=df_if_exists_2, con=self.conn, name=table_name, flavor='sqlite', if_exists='replace') - self.assertEqual(sql.tquery(sql_select, con=self.db), + self.assertEqual(sql.tquery(sql_select, con=self.conn), [(3, 'C'), (4, 'D'), (5, 'E')]) clean_up(table_name) # test if_exists='append' - sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + sql.write_frame(frame=df_if_exists_1, con=self.conn, name=table_name, flavor='sqlite', if_exists='fail') - self.assertEqual(sql.tquery(sql_select, con=self.db), + self.assertEqual(sql.tquery(sql_select, con=self.conn), [(1, 'A'), (2, 'B')]) - sql.write_frame(frame=df_if_exists_2, con=self.db, name=table_name, + sql.write_frame(frame=df_if_exists_2, con=self.conn, name=table_name, flavor='sqlite', if_exists='append') - self.assertEqual(sql.tquery(sql_select, con=self.db), + self.assertEqual(sql.tquery(sql_select, con=self.conn), [(1, 'A'), (2, 'B'), (3, 'C'), (4, 'D'), (5, 'E')]) clean_up(table_name) -class TestXMySQL(tm.TestCase): +class TestXMySQL(MySQLMixIn, tm.TestCase): @classmethod def setUpClass(cls): @@ -2307,14 +2316,14 @@ def setUp(self): try: # Try Travis defaults. # No real user should allow root access with a blank password. - self.db = pymysql.connect(host='localhost', user='root', passwd='', + self.conn = pymysql.connect(host='localhost', user='root', passwd='', db='pandas_nosetest') except: pass else: return try: - self.db = pymysql.connect(read_default_group='pandas') + self.conn = pymysql.connect(read_default_group='pandas') except pymysql.ProgrammingError as e: raise nose.SkipTest( "Create a group of connection parameters under the heading " @@ -2327,12 +2336,6 @@ def setUp(self): "[pandas] in your system's mysql default file, " "typically located at ~/.my.cnf or /etc/.my.cnf. ") - def tearDown(self): - from pymysql.err import Error - try: - self.db.close() - except Error: - pass def test_basic(self): _skip_if_no_pymysql() @@ -2346,7 +2349,7 @@ def test_write_row_by_row(self): frame.ix[0, 0] = np.nan drop_sql = "DROP TABLE IF EXISTS test" create_sql = sql.get_schema(frame, 'test', 'mysql') - cur = self.db.cursor() + cur = self.conn.cursor() cur.execute(drop_sql) cur.execute(create_sql) ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" @@ -2354,9 +2357,9 @@ def test_write_row_by_row(self): fmt_sql = format_query(ins, *row) sql.tquery(fmt_sql, cur=cur) - self.db.commit() + self.conn.commit() - result = sql.read_frame("select * from test", con=self.db) + result = sql.read_frame("select * from test", con=self.conn) result.index = frame.index tm.assert_frame_equal(result, frame) @@ -2365,7 +2368,7 @@ def test_execute(self): frame = tm.makeTimeDataFrame() drop_sql = "DROP TABLE IF EXISTS test" create_sql = sql.get_schema(frame, 'test', 'mysql') - cur = self.db.cursor() + cur = self.conn.cursor() with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unknown table.*") cur.execute(drop_sql) @@ -2373,10 +2376,10 @@ def test_execute(self): ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" row = frame.ix[0].values.tolist() - sql.execute(ins, self.db, params=tuple(row)) - self.db.commit() + sql.execute(ins, self.conn, params=tuple(row)) + self.conn.commit() - result = sql.read_frame("select * from test", self.db) + result = sql.read_frame("select * from test", self.conn) result.index = frame.index[:1] tm.assert_frame_equal(result, frame[:1]) @@ -2395,7 +2398,7 @@ def test_schema(self): create_sql = sql.get_schema(frame, 'test', 'mysql', keys=['A', 'B'],) lines = create_sql.splitlines() self.assertTrue('PRIMARY KEY (`A`, `B`)' in create_sql) - cur = self.db.cursor() + cur = self.conn.cursor() cur.execute(drop_sql) cur.execute(create_sql) @@ -2411,18 +2414,18 @@ def test_execute_fail(self): PRIMARY KEY (a(5), b(5)) ); """ - cur = self.db.cursor() + cur = self.conn.cursor() cur.execute(drop_sql) cur.execute(create_sql) - sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) - sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.db) + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.conn) + sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.conn) try: sys.stdout = StringIO() self.assertRaises(Exception, sql.execute, 'INSERT INTO test VALUES("foo", "bar", 7)', - self.db) + self.conn) finally: sys.stdout = sys.__stdout__ @@ -2438,19 +2441,23 @@ def test_execute_closed_connection(self): PRIMARY KEY (a(5), b(5)) ); """ - cur = self.db.cursor() + cur = self.conn.cursor() cur.execute(drop_sql) cur.execute(create_sql) - sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) - self.db.close() + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.conn) + self.conn.close() try: sys.stdout = StringIO() self.assertRaises(Exception, sql.tquery, "select * from test", - con=self.db) + con=self.conn) finally: sys.stdout = sys.__stdout__ + # Initialize connection again (needed for tearDown) + self.setUp() + + def test_na_roundtrip(self): _skip_if_no_pymysql() pass @@ -2458,12 +2465,12 @@ def test_na_roundtrip(self): def _check_roundtrip(self, frame): _skip_if_no_pymysql() drop_sql = "DROP TABLE IF EXISTS test_table" - cur = self.db.cursor() + cur = self.conn.cursor() with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unknown table.*") cur.execute(drop_sql) - sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') - result = sql.read_frame("select * from test_table", self.db) + sql.write_frame(frame, name='test_table', con=self.conn, flavor='mysql') + result = sql.read_frame("select * from test_table", self.conn) # HACK! Change this once indexes are handled properly. result.index = frame.index @@ -2477,12 +2484,12 @@ def _check_roundtrip(self, frame): index = Index(lrange(len(frame2))) + 10 frame2['Idx'] = index drop_sql = "DROP TABLE IF EXISTS test_table2" - cur = self.db.cursor() + cur = self.conn.cursor() with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unknown table.*") cur.execute(drop_sql) - sql.write_frame(frame2, name='test_table2', con=self.db, flavor='mysql') - result = sql.read_frame("select * from test_table2", self.db, + sql.write_frame(frame2, name='test_table2', con=self.conn, flavor='mysql') + result = sql.read_frame("select * from test_table2", self.conn, index_col='Idx') expected = frame.copy() @@ -2498,10 +2505,10 @@ def test_tquery(self): raise nose.SkipTest("no pymysql") frame = tm.makeTimeDataFrame() drop_sql = "DROP TABLE IF EXISTS test_table" - cur = self.db.cursor() + cur = self.conn.cursor() cur.execute(drop_sql) - sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') - result = sql.tquery("select A from test_table", self.db) + sql.write_frame(frame, name='test_table', con=self.conn, flavor='mysql') + result = sql.tquery("select A from test_table", self.conn) expected = Series(frame.A.values, frame.index) # not to have name result = Series(result, frame.index) tm.assert_series_equal(result, expected) @@ -2509,10 +2516,10 @@ def test_tquery(self): try: sys.stdout = StringIO() self.assertRaises(sql.DatabaseError, sql.tquery, - 'select * from blah', con=self.db) + 'select * from blah', con=self.conn) self.assertRaises(sql.DatabaseError, sql.tquery, - 'select * from blah', con=self.db, retry=True) + 'select * from blah', con=self.conn, retry=True) finally: sys.stdout = sys.__stdout__ @@ -2523,20 +2530,20 @@ def test_uquery(self): raise nose.SkipTest("no pymysql") frame = tm.makeTimeDataFrame() drop_sql = "DROP TABLE IF EXISTS test_table" - cur = self.db.cursor() + cur = self.conn.cursor() cur.execute(drop_sql) - sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') + sql.write_frame(frame, name='test_table', con=self.conn, flavor='mysql') stmt = 'INSERT INTO test_table VALUES(2.314, -123.1, 1.234, 2.3)' - self.assertEqual(sql.uquery(stmt, con=self.db), 1) + self.assertEqual(sql.uquery(stmt, con=self.conn), 1) try: sys.stdout = StringIO() self.assertRaises(sql.DatabaseError, sql.tquery, - 'insert into blah values (1)', con=self.db) + 'insert into blah values (1)', con=self.conn) self.assertRaises(sql.DatabaseError, sql.tquery, - 'insert into blah values (1)', con=self.db, + 'insert into blah values (1)', con=self.conn, retry=True) finally: sys.stdout = sys.__stdout__ @@ -2546,7 +2553,7 @@ def test_keyword_as_column_names(self): ''' _skip_if_no_pymysql() df = DataFrame({'From':np.ones(5)}) - sql.write_frame(df, con = self.db, name = 'testkeywords', + sql.write_frame(df, con = self.conn, name = 'testkeywords', if_exists='replace', flavor='mysql') def test_if_exists(self): @@ -2561,51 +2568,48 @@ def clean_up(test_table_to_drop): Drops tables created from individual tests so no dependencies arise from sequential tests """ - if sql.table_exists(test_table_to_drop, self.db, flavor='mysql'): - cur = self.db.cursor() - cur.execute("DROP TABLE %s" % test_table_to_drop) - cur.close() + self.drop_table(test_table_to_drop) # test if invalid value for if_exists raises appropriate error self.assertRaises(ValueError, sql.write_frame, frame=df_if_exists_1, - con=self.db, + con=self.conn, name=table_name, flavor='mysql', if_exists='notvalidvalue') clean_up(table_name) # test if_exists='fail' - sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + sql.write_frame(frame=df_if_exists_1, con=self.conn, name=table_name, flavor='mysql', if_exists='fail') self.assertRaises(ValueError, sql.write_frame, frame=df_if_exists_1, - con=self.db, + con=self.conn, name=table_name, flavor='mysql', if_exists='fail') # test if_exists='replace' - sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + sql.write_frame(frame=df_if_exists_1, con=self.conn, name=table_name, flavor='mysql', if_exists='replace') - self.assertEqual(sql.tquery(sql_select, con=self.db), + self.assertEqual(sql.tquery(sql_select, con=self.conn), [(1, 'A'), (2, 'B')]) - sql.write_frame(frame=df_if_exists_2, con=self.db, name=table_name, + sql.write_frame(frame=df_if_exists_2, con=self.conn, name=table_name, flavor='mysql', if_exists='replace') - self.assertEqual(sql.tquery(sql_select, con=self.db), + self.assertEqual(sql.tquery(sql_select, con=self.conn), [(3, 'C'), (4, 'D'), (5, 'E')]) clean_up(table_name) # test if_exists='append' - sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + sql.write_frame(frame=df_if_exists_1, con=self.conn, name=table_name, flavor='mysql', if_exists='fail') - self.assertEqual(sql.tquery(sql_select, con=self.db), + self.assertEqual(sql.tquery(sql_select, con=self.conn), [(1, 'A'), (2, 'B')]) - sql.write_frame(frame=df_if_exists_2, con=self.db, name=table_name, + sql.write_frame(frame=df_if_exists_2, con=self.conn, name=table_name, flavor='mysql', if_exists='append') - self.assertEqual(sql.tquery(sql_select, con=self.db), + self.assertEqual(sql.tquery(sql_select, con=self.conn), [(1, 'A'), (2, 'B'), (3, 'C'), (4, 'D'), (5, 'E')]) clean_up(table_name)