From 4610c9051795b33e18aa5fbb791773d08179cb68 Mon Sep 17 00:00:00 2001
From: Fangchen Li
Date: Mon, 9 Aug 2021 17:56:17 -0500
Subject: [PATCH] TST: use fixtures for sql test data

---
 pandas/tests/io/test_sql.py | 337 ++++++++++++++++++++++++++++++++++----------------------
 1 file changed, 154 insertions(+), 183 deletions(-)

diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py
index e924bcef494b9..1ce1bac3b2b7b 100644
--- a/pandas/tests/io/test_sql.py
+++ b/pandas/tests/io/test_sql.py
@@ -222,6 +222,54 @@
 }
 
 
+@pytest.fixture
+def test_frame1():
+    columns = ["index", "A", "B", "C", "D"]
+    data = [
+        (
+            "2000-01-03 00:00:00",
+            0.980268513777,
+            3.68573087906,
+            -0.364216805298,
+            -1.15973806169,
+        ),
+        (
+            "2000-01-04 00:00:00",
+            1.04791624281,
+            -0.0412318367011,
+            -0.16181208307,
+            0.212549316967,
+        ),
+        (
+            "2000-01-05 00:00:00",
+            0.498580885705,
+            0.731167677815,
+            -0.537677223318,
+            1.34627041952,
+        ),
+        (
+            "2000-01-06 00:00:00",
+            1.12020151869,
+            1.56762092543,
+            0.00364077397681,
+            0.67525259227,
+        ),
+    ]
+    return DataFrame(data, columns=columns)
+
+
+@pytest.fixture
+def test_frame3():
+    columns = ["index", "A", "B"]
+    data = [
+        ("2000-01-03 00:00:00", 2 ** 31 - 1, -1.987670),
+        ("2000-01-04 00:00:00", -29, -0.0412318367011),
+        ("2000-01-05 00:00:00", 20000, 0.731167677815),
+        ("2000-01-06 00:00:00", -290867, 1.56762092543),
+    ]
+    return DataFrame(data, columns=columns)
+
+
 class MixInBase:
     def teardown_method(self, method):
         # if setup fails, there may not be a connection to close.
@@ -323,66 +371,6 @@ def _check_iris_loaded_frame(self, iris_frame):
         assert issubclass(pytype, np.floating)
         tm.equalContents(row.values, [5.1, 3.5, 1.4, 0.2, "Iris-setosa"])
 
-    def _load_test1_data(self):
-        columns = ["index", "A", "B", "C", "D"]
-        data = [
-            (
-                "2000-01-03 00:00:00",
-                0.980268513777,
-                3.68573087906,
-                -0.364216805298,
-                -1.15973806169,
-            ),
-            (
-                "2000-01-04 00:00:00",
-                1.04791624281,
-                -0.0412318367011,
-                -0.16181208307,
-                0.212549316967,
-            ),
-            (
-                "2000-01-05 00:00:00",
-                0.498580885705,
-                0.731167677815,
-                -0.537677223318,
-                1.34627041952,
-            ),
-            (
-                "2000-01-06 00:00:00",
-                1.12020151869,
-                1.56762092543,
-                0.00364077397681,
-                0.67525259227,
-            ),
-        ]
-
-        self.test_frame1 = DataFrame(data, columns=columns)
-
-    def _load_test2_data(self):
-        df = DataFrame(
-            {
-                "A": [4, 1, 3, 6],
-                "B": ["asd", "gsq", "ylt", "jkl"],
-                "C": [1.1, 3.1, 6.9, 5.3],
-                "D": [False, True, True, False],
-                "E": ["1990-11-22", "1991-10-26", "1993-11-26", "1995-12-12"],
-            }
-        )
-        df["E"] = to_datetime(df["E"])
-
-        self.test_frame2 = df
-
-    def _load_test3_data(self):
-        columns = ["index", "A", "B"]
-        data = [
-            ("2000-01-03 00:00:00", 2 ** 31 - 1, -1.987670),
-            ("2000-01-04 00:00:00", -29, -0.0412318367011),
-            ("2000-01-05 00:00:00", 20000, 0.731167677815),
-            ("2000-01-06 00:00:00", -290867, 1.56762092543),
-        ]
-
-        self.test_frame3 = DataFrame(data, columns=columns)
-
     def _load_types_test_data(self, data):
         def _filter_to_flavor(flavor, df):
             flavor_dtypes = {
@@ -498,66 +486,66 @@ def _read_sql_iris_no_parameter_with_percent(self):
         iris_frame = self.pandasSQL.read_query(query, params=None)
         self._check_iris_loaded_frame(iris_frame)
 
-    def _to_sql(self, method=None):
+    def _to_sql(self, test_frame1, method=None):
         self.drop_table("test_frame1")
 
-        self.pandasSQL.to_sql(self.test_frame1, "test_frame1", method=method)
+        self.pandasSQL.to_sql(test_frame1, "test_frame1", method=method)
         assert self.pandasSQL.has_table("test_frame1")
 
-        num_entries = len(self.test_frame1)
+        num_entries = len(test_frame1)
         num_rows = self._count_rows("test_frame1")
         assert num_rows == num_entries
 
         # Nuke table
         self.drop_table("test_frame1")
 
-    def _to_sql_empty(self):
+    def _to_sql_empty(self, test_frame1):
         self.drop_table("test_frame1")
-        self.pandasSQL.to_sql(self.test_frame1.iloc[:0], "test_frame1")
+        self.pandasSQL.to_sql(test_frame1.iloc[:0], "test_frame1")
 
-    def _to_sql_fail(self):
+    def _to_sql_fail(self, test_frame1):
         self.drop_table("test_frame1")
 
-        self.pandasSQL.to_sql(self.test_frame1, "test_frame1", if_exists="fail")
+        self.pandasSQL.to_sql(test_frame1, "test_frame1", if_exists="fail")
         assert self.pandasSQL.has_table("test_frame1")
 
         msg = "Table 'test_frame1' already exists"
         with pytest.raises(ValueError, match=msg):
-            self.pandasSQL.to_sql(self.test_frame1, "test_frame1", if_exists="fail")
+            self.pandasSQL.to_sql(test_frame1, "test_frame1", if_exists="fail")
 
         self.drop_table("test_frame1")
 
-    def _to_sql_replace(self):
+    def _to_sql_replace(self, test_frame1):
         self.drop_table("test_frame1")
 
-        self.pandasSQL.to_sql(self.test_frame1, "test_frame1", if_exists="fail")
+        self.pandasSQL.to_sql(test_frame1, "test_frame1", if_exists="fail")
         # Add to table again
-        self.pandasSQL.to_sql(self.test_frame1, "test_frame1", if_exists="replace")
+        self.pandasSQL.to_sql(test_frame1, "test_frame1", if_exists="replace")
         assert self.pandasSQL.has_table("test_frame1")
 
-        num_entries = len(self.test_frame1)
+        num_entries = len(test_frame1)
         num_rows = self._count_rows("test_frame1")
 
         assert num_rows == num_entries
         self.drop_table("test_frame1")
 
-    def _to_sql_append(self):
+    def _to_sql_append(self, test_frame1):
         # Nuke table just in case
         self.drop_table("test_frame1")
 
-        self.pandasSQL.to_sql(self.test_frame1, "test_frame1", if_exists="fail")
+        self.pandasSQL.to_sql(test_frame1, "test_frame1", if_exists="fail")
 
         # Add to table again
-        self.pandasSQL.to_sql(self.test_frame1, "test_frame1", if_exists="append")
+        self.pandasSQL.to_sql(test_frame1, "test_frame1", if_exists="append")
         assert self.pandasSQL.has_table("test_frame1")
 
-        num_entries = 2 * len(self.test_frame1)
+        num_entries = 2 * len(test_frame1)
         num_rows = self._count_rows("test_frame1")
 
         assert num_rows == num_entries
         self.drop_table("test_frame1")
 
-    def _to_sql_method_callable(self):
+    def _to_sql_method_callable(self, test_frame1):
         check = []  # used to double check function below is really being used
 
         def sample(pd_table, conn, keys, data_iter):
@@ -567,36 +555,36 @@ def sample(pd_table, conn, keys, data_iter):
 
         self.drop_table("test_frame1")
 
-        self.pandasSQL.to_sql(self.test_frame1, "test_frame1", method=sample)
+        self.pandasSQL.to_sql(test_frame1, "test_frame1", method=sample)
         assert self.pandasSQL.has_table("test_frame1")
 
         assert check == [1]
-        num_entries = len(self.test_frame1)
+        num_entries = len(test_frame1)
         num_rows = self._count_rows("test_frame1")
         assert num_rows == num_entries
 
         # Nuke table
         self.drop_table("test_frame1")
 
-    def _to_sql_with_sql_engine(self, engine="auto", **engine_kwargs):
+    def _to_sql_with_sql_engine(self, test_frame1, engine="auto", **engine_kwargs):
         """`to_sql` with the `engine` param"""
         # mostly copied from this class's `_to_sql()` method
         self.drop_table("test_frame1")
 
         self.pandasSQL.to_sql(
-            self.test_frame1, "test_frame1", engine=engine, **engine_kwargs
+            test_frame1, "test_frame1", engine=engine, **engine_kwargs
         )
         assert self.pandasSQL.has_table("test_frame1")
 
-        num_entries = len(self.test_frame1)
+        num_entries = len(test_frame1)
         num_rows = self._count_rows("test_frame1")
         assert num_rows == num_entries
 
         # Nuke table
         self.drop_table("test_frame1")
 
-    def _roundtrip(self):
+    def _roundtrip(self, test_frame1):
         self.drop_table("test_frame_roundtrip")
-        self.pandasSQL.to_sql(self.test_frame1, "test_frame_roundtrip")
+        self.pandasSQL.to_sql(test_frame1, "test_frame_roundtrip")
         result = self.pandasSQL.read_query("SELECT * FROM test_frame_roundtrip")
         result.set_index("level_0", inplace=True)
@@ -604,7 +592,7 @@ def _roundtrip(self):
 
         result.index.name = None
 
-        tm.assert_frame_equal(result, self.test_frame1)
+        tm.assert_frame_equal(result, test_frame1)
 
     def _execute_sql(self):
         # drop_sql = "DROP TABLE IF EXISTS test"  # should already be done
@@ -679,9 +667,6 @@ def setup_method(self, load_iris_data):
 
     def load_test_data_and_sql(self):
         self._load_iris_view()
-        self._load_test1_data()
-        self._load_test2_data()
-        self._load_test3_data()
         self._load_raw_sql()
 
     def test_read_sql_iris(self):
@@ -698,46 +683,46 @@ def test_read_sql_with_chunksize_no_result(self):
         without_batch = sql.read_sql_query(query, self.conn)
         tm.assert_frame_equal(concat(with_batch), without_batch)
 
-    def test_to_sql(self):
-        sql.to_sql(self.test_frame1, "test_frame1", self.conn)
+    def test_to_sql(self, test_frame1):
+        sql.to_sql(test_frame1, "test_frame1", self.conn)
         assert sql.has_table("test_frame1", self.conn)
 
-    def test_to_sql_fail(self):
-        sql.to_sql(self.test_frame1, "test_frame2", self.conn, if_exists="fail")
+    def test_to_sql_fail(self, test_frame1):
+        sql.to_sql(test_frame1, "test_frame2", self.conn, if_exists="fail")
        assert sql.has_table("test_frame2", self.conn)
 
         msg = "Table 'test_frame2' already exists"
         with pytest.raises(ValueError, match=msg):
-            sql.to_sql(self.test_frame1, "test_frame2", self.conn, if_exists="fail")
+            sql.to_sql(test_frame1, "test_frame2", self.conn, if_exists="fail")
 
-    def test_to_sql_replace(self):
-        sql.to_sql(self.test_frame1, "test_frame3", self.conn, if_exists="fail")
+    def test_to_sql_replace(self, test_frame1):
+        sql.to_sql(test_frame1, "test_frame3", self.conn, if_exists="fail")
         # Add to table again
-        sql.to_sql(self.test_frame1, "test_frame3", self.conn, if_exists="replace")
+        sql.to_sql(test_frame1, "test_frame3", self.conn, if_exists="replace")
         assert sql.has_table("test_frame3", self.conn)
 
-        num_entries = len(self.test_frame1)
+        num_entries = len(test_frame1)
         num_rows = self._count_rows("test_frame3")
 
         assert num_rows == num_entries
 
-    def test_to_sql_append(self):
-        sql.to_sql(self.test_frame1, "test_frame4", self.conn, if_exists="fail")
+    def test_to_sql_append(self, test_frame1):
+        sql.to_sql(test_frame1, "test_frame4", self.conn, if_exists="fail")
 
         # Add to table again
-        sql.to_sql(self.test_frame1, "test_frame4", self.conn, if_exists="append")
+        sql.to_sql(test_frame1, "test_frame4", self.conn, if_exists="append")
         assert sql.has_table("test_frame4", self.conn)
 
-        num_entries = 2 * len(self.test_frame1)
+        num_entries = 2 * len(test_frame1)
         num_rows = self._count_rows("test_frame4")
 
         assert num_rows == num_entries
 
-    def test_to_sql_type_mapping(self):
-        sql.to_sql(self.test_frame3, "test_frame5", self.conn, index=False)
+    def test_to_sql_type_mapping(self, test_frame3):
+        sql.to_sql(test_frame3, "test_frame5", self.conn, index=False)
         result = sql.read_sql("SELECT * FROM test_frame5", self.conn)
 
-        tm.assert_frame_equal(self.test_frame3, result)
+        tm.assert_frame_equal(test_frame3, result)
 
     def test_to_sql_series(self):
         s = Series(np.arange(5, dtype="int64"), name="series")
@@ -745,27 +730,27 @@ def test_to_sql_series(self):
         s2 = sql.read_sql_query("SELECT * FROM test_series", self.conn)
         tm.assert_frame_equal(s.to_frame(), s2)
 
-    def test_roundtrip(self):
-        sql.to_sql(self.test_frame1, "test_frame_roundtrip", con=self.conn)
+    def test_roundtrip(self, test_frame1):
+        sql.to_sql(test_frame1, "test_frame_roundtrip", con=self.conn)
         result = sql.read_sql_query("SELECT * FROM test_frame_roundtrip", con=self.conn)
 
         # HACK!
-        result.index = self.test_frame1.index
+        result.index = test_frame1.index
         result.set_index("level_0", inplace=True)
         result.index.astype(int)
         result.index.name = None
-        tm.assert_frame_equal(result, self.test_frame1)
+        tm.assert_frame_equal(result, test_frame1)
 
-    def test_roundtrip_chunksize(self):
+    def test_roundtrip_chunksize(self, test_frame1):
         sql.to_sql(
-            self.test_frame1,
+            test_frame1,
             "test_frame_roundtrip",
             con=self.conn,
             index=False,
             chunksize=2,
         )
         result = sql.read_sql_query("SELECT * FROM test_frame_roundtrip", con=self.conn)
-        tm.assert_frame_equal(result, self.test_frame1)
+        tm.assert_frame_equal(result, test_frame1)
 
     def test_execute_sql(self):
         # drop_sql = "DROP TABLE IF EXISTS test"  # should already be done
@@ -999,15 +984,13 @@ def test_integer_col_names(self):
         df = DataFrame([[1, 2], [3, 4]], columns=[0, 1])
         sql.to_sql(df, "test_frame_integer_col_names", self.conn, if_exists="replace")
 
-    def test_get_schema(self):
-        create_sql = sql.get_schema(self.test_frame1, "test", con=self.conn)
+    def test_get_schema(self, test_frame1):
+        create_sql = sql.get_schema(test_frame1, "test", con=self.conn)
         assert "CREATE" in create_sql
 
-    def test_get_schema_with_schema(self):
+    def test_get_schema_with_schema(self, test_frame1):
         # GH28486
-        create_sql = sql.get_schema(
-            self.test_frame1, "test", con=self.conn, schema="pypi"
-        )
+        create_sql = sql.get_schema(test_frame1, "test", con=self.conn, schema="pypi")
         assert "CREATE TABLE pypi." in create_sql
 
     def test_get_schema_dtypes(self):
@@ -1019,16 +1002,14 @@ def test_get_schema_dtypes(self):
         assert "CREATE" in create_sql
         assert "INTEGER" in create_sql
 
-    def test_get_schema_keys(self):
+    def test_get_schema_keys(self, test_frame1):
         frame = DataFrame({"Col1": [1.1, 1.2], "Col2": [2.1, 2.2]})
         create_sql = sql.get_schema(frame, "test", con=self.conn, keys="Col1")
         constraint_sentence = 'CONSTRAINT test_pk PRIMARY KEY ("Col1")'
         assert constraint_sentence in create_sql
 
         # multiple columns as key (GH10385)
-        create_sql = sql.get_schema(
-            self.test_frame1, "test", con=self.conn, keys=["A", "B"]
-        )
+        create_sql = sql.get_schema(test_frame1, "test", con=self.conn, keys=["A", "B"])
         constraint_sentence = 'CONSTRAINT test_pk PRIMARY KEY ("A", "B")'
         assert constraint_sentence in create_sql
 
@@ -1115,17 +1096,17 @@ class TestSQLApi(SQLAlchemyMixIn, _TestSQLApi):
     def connect(self):
         return sqlalchemy.create_engine("sqlite:///:memory:")
 
-    def test_read_table_columns(self):
+    def test_read_table_columns(self, test_frame1):
         # test columns argument in read_table
-        sql.to_sql(self.test_frame1, "test_frame", self.conn)
+        sql.to_sql(test_frame1, "test_frame", self.conn)
 
         cols = ["A", "B"]
         result = sql.read_sql_table("test_frame", self.conn, columns=cols)
         assert result.columns.tolist() == cols
 
-    def test_read_table_index_col(self):
+    def test_read_table_index_col(self, test_frame1):
         # test columns argument in read_table
-        sql.to_sql(self.test_frame1, "test_frame", self.conn)
+        sql.to_sql(test_frame1, "test_frame", self.conn)
 
         result = sql.read_sql_table("test_frame", self.conn, index_col="index")
         assert result.index.names == ["index"]
@@ -1164,7 +1145,7 @@ def test_not_reflect_all_tables(self):
             # Verify some things
             assert len(w) == 0
 
-    def test_warning_case_insensitive_table_name(self):
+    def test_warning_case_insensitive_table_name(self, test_frame1):
         # see gh-7815
         #
         # We can't test that this warning is triggered, a the database
@@ -1174,7 +1155,7 @@ def test_warning_case_insensitive_table_name(self):
             # Cause all warnings to always be triggered.
             warnings.simplefilter("always")
             # This should not trigger a Warning
-            self.test_frame1.to_sql("CaseSensitive", self.conn)
+            test_frame1.to_sql("CaseSensitive", self.conn)
             # Verify some things
             assert len(w) == 0
 
@@ -1236,10 +1217,8 @@ def test_sqlalchemy_integer_overload_mapping(self, integer):
         ):
             sql.SQLTable("test_type", db, frame=df)
 
-    def test_database_uri_string(self):
-
+    def test_database_uri_string(self, test_frame1):
         # Test read_sql and .to_sql method with a database URI (GH10654)
-        test_frame1 = self.test_frame1
         # db_uri = 'sqlite:///:memory:' # raises
         # sqlalchemy.exc.OperationalError: (sqlite3.OperationalError) near
         # "iris": syntax error [SQL: 'iris']
@@ -1353,21 +1332,21 @@ class TestSQLiteFallbackApi(SQLiteMixIn, _TestSQLApi):
     def connect(self, database=":memory:"):
         return sqlite3.connect(database)
 
-    def test_sql_open_close(self):
+    def test_sql_open_close(self, test_frame3):
         # Test if the IO in the database still work if the connection closed
         # between the writing and reading (as in many real situations).
 
         with tm.ensure_clean() as name:
 
             conn = self.connect(name)
-            sql.to_sql(self.test_frame3, "test_frame3_legacy", conn, index=False)
+            sql.to_sql(test_frame3, "test_frame3_legacy", conn, index=False)
             conn.close()
 
             conn = self.connect(name)
             result = sql.read_sql_query("SELECT * FROM test_frame3_legacy;", conn)
             conn.close()
 
-        tm.assert_frame_equal(self.test_frame3, result)
+        tm.assert_frame_equal(test_frame3, result)
 
     @pytest.mark.skipif(SQLALCHEMY_INSTALLED, reason="SQLAlchemy is installed")
     def test_con_string_import_error(self):
@@ -1391,9 +1370,9 @@ def test_safe_names_warning(self):
         with tm.assert_produces_warning():
             sql.to_sql(df, "test_frame3_legacy", self.conn, index=False)
 
-    def test_get_schema2(self):
+    def test_get_schema2(self, test_frame1):
         # without providing a connection object (available for backwards comp)
-        create_sql = sql.get_schema(self.test_frame1, "test")
+        create_sql = sql.get_schema(test_frame1, "test")
         assert "CREATE" in create_sql
 
     def _get_sqlite_column_type(self, schema, column):
@@ -1439,7 +1418,6 @@ def setup_class(cls):
 
     def load_test_data_and_sql(self):
         self._load_raw_sql()
-        self._load_test1_data()
 
     @pytest.fixture(autouse=True)
     def setup_method(self, load_iris_data):
@@ -1477,26 +1455,26 @@ def test_read_sql_parameter(self):
     def test_read_sql_named_parameter(self):
         self._read_sql_iris_named_parameter()
 
-    def test_to_sql(self):
-        self._to_sql()
+    def test_to_sql(self, test_frame1):
+        self._to_sql(test_frame1)
 
-    def test_to_sql_empty(self):
-        self._to_sql_empty()
+    def test_to_sql_empty(self, test_frame1):
+        self._to_sql_empty(test_frame1)
 
-    def test_to_sql_fail(self):
-        self._to_sql_fail()
+    def test_to_sql_fail(self, test_frame1):
+        self._to_sql_fail(test_frame1)
 
-    def test_to_sql_replace(self):
-        self._to_sql_replace()
+    def test_to_sql_replace(self, test_frame1):
+        self._to_sql_replace(test_frame1)
 
-    def test_to_sql_append(self):
-        self._to_sql_append()
+    def test_to_sql_append(self, test_frame1):
+        self._to_sql_append(test_frame1)
 
-    def test_to_sql_method_multi(self):
-        self._to_sql(method="multi")
+    def test_to_sql_method_multi(self, test_frame1):
+        self._to_sql(test_frame1, method="multi")
 
-    def test_to_sql_method_callable(self):
-        self._to_sql_method_callable()
+    def test_to_sql_method_callable(self, test_frame1):
+        self._to_sql_method_callable(test_frame1)
 
     def test_create_table(self):
         temp_conn = self.connect()
@@ -1536,8 +1514,8 @@ def test_drop_table(self):
         else:
             assert not temp_conn.has_table("temp_frame")
 
-    def test_roundtrip(self):
-        self._roundtrip()
+    def test_roundtrip(self, test_frame1):
+        self._roundtrip(test_frame1)
 
     def test_execute_sql(self):
         self._execute_sql()
@@ -1888,15 +1866,14 @@ def test_to_sql_save_index(self):
 
     def test_transactions(self):
         self._transaction_test()
 
-    def test_get_schema_create_table(self):
+    def test_get_schema_create_table(self, test_frame3):
         # Use a dataframe without a bool column, since MySQL converts bool to
         # TINYINT (which read_sql_table returns as an int and causes a dtype
         # mismatch)
-        self._load_test3_data()
         tbl = "test_get_schema_create_table"
-        create_sql = sql.get_schema(self.test_frame3, tbl, con=self.conn)
-        blank_test_df = self.test_frame3.iloc[:0]
+        create_sql = sql.get_schema(test_frame3, tbl, con=self.conn)
+        blank_test_df = test_frame3.iloc[:0]
 
         self.drop_table(tbl)
         self.conn.execute(create_sql)
@@ -2072,22 +2049,22 @@ class Temporary(Base):
 
             tm.assert_frame_equal(df, expected)
 
     # -- SQL Engine tests (in the base class for now)
-    def test_invalid_engine(self):
+    def test_invalid_engine(self, test_frame1):
         msg = "engine must be one of 'auto', 'sqlalchemy'"
         with pytest.raises(ValueError, match=msg):
-            self._to_sql_with_sql_engine("bad_engine")
+            self._to_sql_with_sql_engine(test_frame1, "bad_engine")
 
-    def test_options_sqlalchemy(self):
+    def test_options_sqlalchemy(self, test_frame1):
         # use the set option
         with pd.option_context("io.sql.engine", "sqlalchemy"):
-            self._to_sql_with_sql_engine()
+            self._to_sql_with_sql_engine(test_frame1)
 
-    def test_options_auto(self):
+    def test_options_auto(self, test_frame1):
         # use the set option
         with pd.option_context("io.sql.engine", "auto"):
-            self._to_sql_with_sql_engine()
+            self._to_sql_with_sql_engine(test_frame1)
 
     def test_options_get_engine(self):
         assert isinstance(get_engine("sqlalchemy"), SQLAlchemyEngine)
@@ -2405,13 +2380,9 @@ def connect(cls):
     def setup_connect(self):
         self.conn = self.connect()
 
-    def load_test_data_and_sql(self):
-        self.pandasSQL = sql.SQLiteDatabase(self.conn)
-        self._load_test1_data()
-
     @pytest.fixture(autouse=True)
     def setup_method(self, load_iris_data):
-        self.load_test_data_and_sql()
+        self.pandasSQL = sql.SQLiteDatabase(self.conn)
 
     def test_read_sql(self):
         self._read_sql_iris()
@@ -2422,24 +2393,24 @@ def test_read_sql_parameter(self):
     def test_read_sql_named_parameter(self):
         self._read_sql_iris_named_parameter()
 
-    def test_to_sql(self):
-        self._to_sql()
+    def test_to_sql(self, test_frame1):
+        self._to_sql(test_frame1)
 
-    def test_to_sql_empty(self):
-        self._to_sql_empty()
+    def test_to_sql_empty(self, test_frame1):
+        self._to_sql_empty(test_frame1)
 
-    def test_to_sql_fail(self):
-        self._to_sql_fail()
+    def test_to_sql_fail(self, test_frame1):
+        self._to_sql_fail(test_frame1)
 
-    def test_to_sql_replace(self):
-        self._to_sql_replace()
+    def test_to_sql_replace(self, test_frame1):
+        self._to_sql_replace(test_frame1)
 
-    def test_to_sql_append(self):
-        self._to_sql_append()
+    def test_to_sql_append(self, test_frame1):
+        self._to_sql_append(test_frame1)
 
-    def test_to_sql_method_multi(self):
+    def test_to_sql_method_multi(self, test_frame1):
         # GH 29921
-        self._to_sql(method="multi")
+        self._to_sql(test_frame1, method="multi")
 
     def test_create_and_drop_table(self):
         temp_frame = DataFrame(
@@ -2454,8 +2425,8 @@ def test_create_and_drop_table(self):
 
         assert not self.pandasSQL.has_table("drop_test_frame")
 
-    def test_roundtrip(self):
-        self._roundtrip()
+    def test_roundtrip(self, test_frame1):
+        self._roundtrip(test_frame1)
 
     def test_execute_sql(self):
         self._execute_sql()