From 30c66798ac6aa1869a7a1239709483313eda72dd Mon Sep 17 00:00:00 2001
From: Artemy Kolchinsky
Date: Mon, 15 Sep 2014 19:00:46 -0700
Subject: [PATCH] Avoid transaction context-managers for legacy SQL drivers
 (GH8277)

Eliminating contextmanager based transaction-handling

Rewriting as contextmanager

Code review fixes

Changing sqlalchemy version to 0.7.10 for python2.6
---
Review note: the legacy (DBAPI connection) run_transaction() in
pandas/io/sql.py now re-raises after rolling back. Without the re-raise the
generator-based context manager suppressed every exception raised inside the
"with" block, so failed inserts were rolled back silently, unlike the
SQLAlchemy path (engine.begin()), which propagates the error. The hunk header
and diffstat below were adjusted for the one extra line; the post-image blob
id on the sql.py "index" line still refers to the pre-review revision.

 ci/requirements-2.6.txt     |  2 +-
 pandas/io/sql.py            | 39 ++++++++++++++++++++++++---------------
 pandas/io/tests/test_sql.py | 26 ++++++++++++++++++++++++++
 3 files changed, 51 insertions(+), 16 deletions(-)

diff --git a/ci/requirements-2.6.txt b/ci/requirements-2.6.txt
index baba82f588ed6..3a845f4ee0540 100644
--- a/ci/requirements-2.6.txt
+++ b/ci/requirements-2.6.txt
@@ -5,7 +5,7 @@ pytz==2013b
 http://www.crummy.com/software/BeautifulSoup/bs4/download/4.2/beautifulsoup4-4.2.0.tar.gz
 html5lib==1.0b2
 numexpr==1.4.2
-sqlalchemy==0.7.4
+sqlalchemy==0.7.10
 pymysql==0.6.0
 psycopg2==2.5
 scipy==0.11.0
diff --git a/pandas/io/sql.py b/pandas/io/sql.py
index 83b96d5186dd2..513ac1241ffdb 100644
--- a/pandas/io/sql.py
+++ b/pandas/io/sql.py
@@ -19,6 +19,7 @@
 from pandas.core.base import PandasObject
 from pandas.tseries.tools import to_datetime
 
+from contextlib import contextmanager
 
 class SQLAlchemyRequired(ImportError):
     pass
@@ -637,13 +638,9 @@ def insert_data(self):
 
         return column_names, data_list
 
-    def get_session(self):
-        con = self.pd_sql.engine.connect()
-        return con.begin()
-
-    def _execute_insert(self, trans, keys, data_iter):
+    def _execute_insert(self, conn, keys, data_iter):
         data = [dict( (k, v) for k, v in zip(keys, row) ) for row in data_iter]
-        trans.connection.execute(self.insert_statement(), data)
+        conn.execute(self.insert_statement(), data)
 
     def insert(self, chunksize=None):
         keys, data_list = self.insert_data()
@@ -653,7 +650,7 @@ def insert(self, chunksize=None):
             chunksize = nrows
         chunks = int(nrows / chunksize) + 1
 
-        with self.get_session() as trans:
+        with self.pd_sql.run_transaction() as conn:
             for i in range(chunks):
                 start_i = i * chunksize
                 end_i = min((i + 1) * chunksize, nrows)
@@ -661,7 +658,7 @@ def insert(self, chunksize=None):
                     break
 
                 chunk_iter = zip(*[arr[start_i:end_i] for arr in data_list])
-                self._execute_insert(trans, keys, chunk_iter)
+                self._execute_insert(conn, keys, chunk_iter)
 
     def read(self, coerce_float=True, parse_dates=None, columns=None):
 
@@ -884,6 +881,9 @@ def __init__(self, engine, schema=None, meta=None):
 
         self.meta = meta
 
+    def run_transaction(self):
+        return self.engine.begin()
+
     def execute(self, *args, **kwargs):
         """Simple passthrough to SQLAlchemy engine"""
         return self.engine.execute(*args, **kwargs)
@@ -1017,9 +1017,9 @@ def sql_schema(self):
         return str(";\n".join(self.table))
 
     def _execute_create(self):
-        with self.get_session():
+        with self.pd_sql.run_transaction() as conn:
             for stmt in self.table:
-                self.pd_sql.execute(stmt)
+                conn.execute(stmt)
 
     def insert_statement(self):
         names = list(map(str, self.frame.columns))
@@ -1038,12 +1038,9 @@ def insert_statement(self):
             self.name, col_names, wildcards)
         return insert_statement
 
-    def get_session(self):
-        return self.pd_sql.con
-
-    def _execute_insert(self, trans, keys, data_iter):
+    def _execute_insert(self, conn, keys, data_iter):
         data_list = list(data_iter)
-        trans.executemany(self.insert_statement(), data_list)
+        conn.executemany(self.insert_statement(), data_list)
 
     def _create_table_setup(self):
         """Return a list of SQL statement that create a table reflecting the
@@ -1125,6 +1122,18 @@ def __init__(self, con, flavor, is_cursor=False):
         else:
             self.flavor = flavor
 
+    @contextmanager
+    def run_transaction(self):
+        cur = self.con.cursor()
+        try:
+            yield cur
+            self.con.commit()
+        except:
+            self.con.rollback()
+            raise
+        finally:
+            cur.close()
+
     def execute(self, *args, **kwargs):
         if self.is_cursor:
             cur = self.con
diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py
index 80988ab2f5e1c..f02c701d97bcf 100644
--- a/pandas/io/tests/test_sql.py
+++ b/pandas/io/tests/test_sql.py
@@ -331,6 +331,28 @@ def _to_sql_save_index(self):
         ix_cols = self._get_index_columns('test_to_sql_saves_index')
         self.assertEqual(ix_cols, [['A',],])
 
+    def _transaction_test(self):
+        self.pandasSQL.execute("CREATE TABLE test_trans (A INT, B TEXT)")
+
+        ins_sql = "INSERT INTO test_trans (A,B) VALUES (1, 'blah')"
+
+        # Make sure when transaction is rolled back, no rows get inserted
+        try:
+            with self.pandasSQL.run_transaction() as trans:
+                trans.execute(ins_sql)
+                raise Exception('error')
+        except:
+            # ignore raised exception
+            pass
+        res = self.pandasSQL.read_sql('SELECT * FROM test_trans')
+        self.assertEqual(len(res), 0)
+
+        # Make sure when transaction is committed, rows do get inserted
+        with self.pandasSQL.run_transaction() as trans:
+            trans.execute(ins_sql)
+        res2 = self.pandasSQL.read_sql('SELECT * FROM test_trans')
+        self.assertEqual(len(res2), 1)
+
 
 #------------------------------------------------------------------------------
 #--- Testing the public API
@@ -1072,6 +1094,8 @@ def _get_index_columns(self, tbl_name):
     def test_to_sql_save_index(self):
         self._to_sql_save_index()
 
+    def test_transactions(self):
+        self._transaction_test()
 
 class TestSQLiteAlchemy(_TestSQLAlchemy):
     """
@@ -1380,6 +1404,8 @@ def _get_index_columns(self, tbl_name):
     def test_to_sql_save_index(self):
         self._to_sql_save_index()
 
+    def test_transactions(self):
+        self._transaction_test()
 
 class TestMySQLLegacy(TestSQLiteLegacy):
     """