From e9f1868c4616e3ef1c1e3b5842bb7b4cc5ca4475 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Thu, 30 Dec 2021 20:07:36 -0800 Subject: [PATCH 1/8] ENH: to_sql returns rowcount --- doc/source/whatsnew/v1.4.0.rst | 2 +- pandas/core/generic.py | 18 ++- pandas/io/sql.py | 72 +++++++---- pandas/tests/io/test_sql.py | 229 ++++++++++++++++++++------------- 4 files changed, 205 insertions(+), 116 deletions(-) diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index 3924191bebcfd..25ca8506d53ad 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -236,7 +236,7 @@ Other enhancements - :meth:`is_list_like` now identifies duck-arrays as list-like unless ``.ndim == 0`` (:issue:`35131`) - :class:`ExtensionDtype` and :class:`ExtensionArray` are now (de)serialized when exporting a :class:`DataFrame` with :meth:`DataFrame.to_json` using ``orient='table'`` (:issue:`20612`, :issue:`44705`). - Add support for `Zstandard `_ compression to :meth:`DataFrame.to_pickle`/:meth:`read_pickle` and friends (:issue:`43925`) -- +- :meth:`DataFrame.to_sql` now returns an ``int`` of the number of written rows (:issue:`23998`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 0034d0511c15e..4235a14c6c63a 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2820,6 +2820,20 @@ def to_sql( Details and a sample callable implementation can be found in the section :ref:`insert method `. + Returns + ------- + None or int + Number of rows affected by to_sql. None is returned if the callable + passed into ``method`` does not return the number of rows. + + The number of returned rows affected is the sum of the ``rowcount`` + attribute of ``sqlite3.Cursor`` or SQLAlchemy connectable which may not + reflect the exact number of written rows as stipulated in the + `sqlite3 `__ or + `SQLAlchemy `__ + + .. versionadded:: 1.4.0 + Raises ------ ValueError @@ -2903,10 +2917,10 @@ def to_sql( >>> engine.execute("SELECT * FROM integers").fetchall() [(1,), (None,), (2,)] - """ + """ # noqa: 501 from pandas.io import sql - sql.to_sql( + return sql.to_sql( self, name, con, diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 548bd617a285f..f70bb03b3ff2f 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -611,7 +611,7 @@ def to_sql( method: str | None = None, engine: str = "auto", **engine_kwargs, -) -> None: +) -> int | None: """ Write records stored in a DataFrame to a SQL database. @@ -650,8 +650,8 @@ def to_sql( Controls the SQL insertion clause used: - None : Uses standard SQL ``INSERT`` clause (one per row). - - 'multi': Pass multiple values in a single ``INSERT`` clause. - - callable with signature ``(pd_table, conn, keys, data_iter)``. + - ``'multi'``: Pass multiple values in a single ``INSERT`` clause. + - callable with signature ``(pd_table, conn, keys, data_iter) -> int | None``. Details and a sample callable implementation can be found in the section :ref:`insert method `. @@ -664,7 +664,23 @@ def to_sql( **engine_kwargs Any additional kwargs are passed to the engine. - """ + + Returns + ------- + None or int + Number of rows affected by to_sql. None is returned if the callable + passed into ``method`` does not return the number of rows. + + .. versionadded:: 1.4.0 + + Notes + ----- + The returned rows affected is the sum of the ``rowcount`` attribute of ``sqlite3.Cursor`` + or SQLAlchemy connectable. 
The returned value may not reflect the exact number of written + rows as stipulated in the + `sqlite3 `__ or + `SQLAlchemy `__ + """ # noqa: 501 if if_exists not in ("fail", "replace", "append"): raise ValueError(f"'{if_exists}' is not valid for if_exists") @@ -677,7 +693,7 @@ def to_sql( "'frame' argument should be either a Series or a DataFrame" ) - pandas_sql.to_sql( + return pandas_sql.to_sql( frame, name, if_exists=if_exists, @@ -817,7 +833,7 @@ def create(self): else: self._execute_create() - def _execute_insert(self, conn, keys: list[str], data_iter): + def _execute_insert(self, conn, keys: list[str], data_iter) -> int: """ Execute SQL statement inserting data @@ -830,9 +846,10 @@ def _execute_insert(self, conn, keys: list[str], data_iter): Each item contains a list of values to be inserted """ data = [dict(zip(keys, row)) for row in data_iter] - conn.execute(self.table.insert(), data) + result = conn.execute(self.table.insert(), data) + return result.rowcount - def _execute_insert_multi(self, conn, keys: list[str], data_iter): + def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int: """ Alternative to _execute_insert for DBs support multivalue INSERT. @@ -846,6 +863,7 @@ def _execute_insert_multi(self, conn, keys: list[str], data_iter): data = [dict(zip(keys, row)) for row in data_iter] stmt = insert(self.table).values(data) conn.execute(stmt) + return conn.rowcount def insert_data(self): if self.index is not None: @@ -885,7 +903,9 @@ def insert_data(self): return column_names, data_list - def insert(self, chunksize: int | None = None, method: str | None = None): + def insert( + self, chunksize: int | None = None, method: str | None = None + ) -> int | None: # set insert method if method is None: @@ -902,7 +922,7 @@ def insert(self, chunksize: int | None = None, method: str | None = None): nrows = len(self.frame) if nrows == 0: - return + return 0 if chunksize is None: chunksize = nrows @@ -910,7 +930,7 @@ def insert(self, chunksize: int | None = None, method: str | None = None): raise ValueError("chunksize argument should be non-zero") chunks = (nrows // chunksize) + 1 - + total_inserted = 0 with self.pd_sql.run_transaction() as conn: for i in range(chunks): start_i = i * chunksize @@ -919,7 +939,12 @@ def insert(self, chunksize: int | None = None, method: str | None = None): break chunk_iter = zip(*(arr[start_i:end_i] for arr in data_list)) - exec_insert(conn, keys, chunk_iter) + num_inserted = exec_insert(conn, keys, chunk_iter) + if num_inserted is None: + total_inserted = None + else: + total_inserted += num_inserted + return total_inserted def _query_iterator( self, @@ -1239,7 +1264,7 @@ def to_sql( chunksize=None, dtype: DtypeArg | None = None, method=None, - ): + ) -> int | None: raise ValueError( "PandasSQL must be created with an SQLAlchemy " "connectable or sqlite connection" @@ -1258,7 +1283,7 @@ def insert_records( chunksize=None, method=None, **engine_kwargs, - ): + ) -> int | None: """ Inserts data into already-prepared table """ @@ -1282,11 +1307,11 @@ def insert_records( chunksize=None, method=None, **engine_kwargs, - ): + ) -> int | None: from sqlalchemy import exc try: - table.insert(chunksize=chunksize, method=method) + return table.insert(chunksize=chunksize, method=method) except exc.SQLAlchemyError as err: # GH34431 # https://stackoverflow.com/a/67358288/6067848 @@ -1643,7 +1668,7 @@ def to_sql( method=None, engine="auto", **engine_kwargs, - ): + ) -> int | None: """ Write records stored in a DataFrame to a SQL database. 
@@ -1704,7 +1729,7 @@ def to_sql( dtype=dtype, ) - sql_engine.insert_records( + total_inserted = sql_engine.insert_records( table=table, con=self.connectable, frame=frame, @@ -1717,6 +1742,7 @@ def to_sql( ) self.check_case_sensitive(name=name, schema=schema) + return total_inserted @property def tables(self): @@ -1859,14 +1885,16 @@ def insert_statement(self, *, num_rows: int): ) return insert_statement - def _execute_insert(self, conn, keys, data_iter): + def _execute_insert(self, conn, keys, data_iter) -> int: data_list = list(data_iter) conn.executemany(self.insert_statement(num_rows=1), data_list) + return conn.rowcount - def _execute_insert_multi(self, conn, keys, data_iter): + def _execute_insert_multi(self, conn, keys, data_iter) -> int: data_list = list(data_iter) flattened_data = [x for row in data_list for x in row] conn.execute(self.insert_statement(num_rows=len(data_list)), flattened_data) + return conn.rowcount def _create_table_setup(self): """ @@ -2088,7 +2116,7 @@ def to_sql( dtype: DtypeArg | None = None, method=None, **kwargs, - ): + ) -> int | None: """ Write records stored in a DataFrame to a SQL database. @@ -2153,7 +2181,7 @@ def to_sql( dtype=dtype, ) table.create() - table.insert(chunksize, method) + return table.insert(chunksize, method) def has_table(self, name: str, schema: str | None = None): diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 2f988d825d9db..3eba3762d2a34 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -726,15 +726,18 @@ def _read_sql_iris_no_parameter_with_percent(self): def _to_sql_empty(self, test_frame1): self.drop_table("test_frame1") - self.pandasSQL.to_sql(test_frame1.iloc[:0], "test_frame1") + assert self.pandasSQL.to_sql(test_frame1.iloc[:0], "test_frame1") == 0 def _to_sql_with_sql_engine(self, test_frame1, engine="auto", **engine_kwargs): """`to_sql` with the `engine` param""" # mostly copied from this class's `_to_sql()` method self.drop_table("test_frame1") - self.pandasSQL.to_sql( - test_frame1, "test_frame1", engine=engine, **engine_kwargs + assert ( + self.pandasSQL.to_sql( + test_frame1, "test_frame1", engine=engine, **engine_kwargs + ) + == 4 ) assert self.pandasSQL.has_table("test_frame1") @@ -747,7 +750,7 @@ def _to_sql_with_sql_engine(self, test_frame1, engine="auto", **engine_kwargs): def _roundtrip(self, test_frame1): self.drop_table("test_frame_roundtrip") - self.pandasSQL.to_sql(test_frame1, "test_frame_roundtrip") + assert self.pandasSQL.to_sql(test_frame1, "test_frame_roundtrip") == 4 result = self.pandasSQL.read_query("SELECT * FROM test_frame_roundtrip") result.set_index("level_0", inplace=True) @@ -767,7 +770,7 @@ def _to_sql_save_index(self): df = DataFrame.from_records( [(1, 2.1, "line1"), (2, 1.5, "line2")], columns=["A", "B", "C"], index=["A"] ) - self.pandasSQL.to_sql(df, "test_to_sql_saves_index") + assert self.pandasSQL.to_sql(df, "test_to_sql_saves_index") == 2 ix_cols = self._get_index_columns("test_to_sql_saves_index") assert ix_cols == [["A"]] @@ -876,10 +879,12 @@ def test_to_sql_replace(self, test_frame1): assert num_rows == num_entries def test_to_sql_append(self, test_frame1): - sql.to_sql(test_frame1, "test_frame4", self.conn, if_exists="fail") + assert sql.to_sql(test_frame1, "test_frame4", self.conn, if_exists="fail") == 4 # Add to table again - sql.to_sql(test_frame1, "test_frame4", self.conn, if_exists="append") + assert ( + sql.to_sql(test_frame1, "test_frame4", self.conn, if_exists="append") == 4 + ) assert sql.has_table("test_frame4", 
self.conn) num_entries = 2 * len(test_frame1) @@ -1030,7 +1035,8 @@ def test_timedelta(self): # see #6921 df = to_timedelta(Series(["00:00:01", "00:00:03"], name="foo")).to_frame() with tm.assert_produces_warning(UserWarning): - df.to_sql("test_timedelta", self.conn) + result_count = df.to_sql("test_timedelta", self.conn) + assert result_count is None result = sql.read_sql_query("SELECT * FROM test_timedelta", self.conn) tm.assert_series_equal(result["foo"], df["foo"].view("int64")) @@ -1038,7 +1044,7 @@ def test_complex_raises(self): df = DataFrame({"a": [1 + 1j, 2j]}) msg = "Complex datatypes not supported" with pytest.raises(ValueError, match=msg): - df.to_sql("test_complex", self.conn) + assert df.to_sql("test_complex", self.conn) is None @pytest.mark.parametrize( "index_name,index_label,expected", @@ -1066,42 +1072,49 @@ def test_to_sql_index_label(self, index_name, index_label, expected): assert frame.columns[0] == expected def test_to_sql_index_label_multiindex(self): + expected_row_count = 4 temp_frame = DataFrame( {"col1": range(4)}, index=MultiIndex.from_product([("A0", "A1"), ("B0", "B1")]), ) # no index name, defaults to 'level_0' and 'level_1' - sql.to_sql(temp_frame, "test_index_label", self.conn) + result = sql.to_sql(temp_frame, "test_index_label", self.conn) + assert result == expected_row_count frame = sql.read_sql_query("SELECT * FROM test_index_label", self.conn) assert frame.columns[0] == "level_0" assert frame.columns[1] == "level_1" # specifying index_label - sql.to_sql( + result = sql.to_sql( temp_frame, "test_index_label", self.conn, if_exists="replace", index_label=["A", "B"], ) + assert result == expected_row_count frame = sql.read_sql_query("SELECT * FROM test_index_label", self.conn) assert frame.columns[:2].tolist() == ["A", "B"] # using the index name temp_frame.index.names = ["A", "B"] - sql.to_sql(temp_frame, "test_index_label", self.conn, if_exists="replace") + result = sql.to_sql( + temp_frame, "test_index_label", self.conn, if_exists="replace" + ) + assert result == expected_row_count frame = sql.read_sql_query("SELECT * FROM test_index_label", self.conn) assert frame.columns[:2].tolist() == ["A", "B"] # has index name, but specifying index_label - sql.to_sql( + result = sql.to_sql( temp_frame, "test_index_label", self.conn, if_exists="replace", index_label=["C", "D"], ) + assert result == expected_row_count frame = sql.read_sql_query("SELECT * FROM test_index_label", self.conn) assert frame.columns[:2].tolist() == ["C", "D"] @@ -1140,7 +1153,7 @@ def test_multiindex_roundtrip(self): def test_dtype_argument(self, dtype): # GH10285 Add dtype argument to read_sql_query df = DataFrame([[1.2, 3.4], [5.6, 7.8]], columns=["A", "B"]) - df.to_sql("test_dtype_argument", self.conn) + assert df.to_sql("test_dtype_argument", self.conn) == 2 expected = df.astype(dtype) result = sql.read_sql_query( @@ -1505,7 +1518,7 @@ def test_sql_open_close(self, test_frame3): with tm.ensure_clean() as name: conn = self.connect(name) - sql.to_sql(test_frame3, "test_frame3_legacy", conn, index=False) + assert sql.to_sql(test_frame3, "test_frame3_legacy", conn, index=False) == 4 conn.close() conn = self.connect(name) @@ -1621,7 +1634,7 @@ def test_create_table(self): ) pandasSQL = sql.SQLDatabase(temp_conn) - pandasSQL.to_sql(temp_frame, "temp_frame") + assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4 if _gt14(): from sqlalchemy import inspect @@ -1639,7 +1652,7 @@ def test_drop_table(self): ) pandasSQL = sql.SQLDatabase(temp_conn) - pandasSQL.to_sql(temp_frame, "temp_frame") + 
assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4 if _gt14(): from sqlalchemy import inspect @@ -1692,7 +1705,7 @@ def test_default_type_conversion(self): def test_bigint(self): # int64 should be converted to BigInteger, GH7433 df = DataFrame(data={"i64": [2 ** 62]}) - df.to_sql("test_bigint", self.conn, index=False) + assert df.to_sql("test_bigint", self.conn, index=False) == 1 result = sql.read_sql_table("test_bigint", self.conn) tm.assert_frame_equal(df, result) @@ -1788,7 +1801,7 @@ def test_datetime_with_timezone_roundtrip(self): expected = DataFrame( {"A": date_range("2013-01-01 09:00:00", periods=3, tz="US/Pacific")} ) - expected.to_sql("test_datetime_tz", self.conn, index=False) + assert expected.to_sql("test_datetime_tz", self.conn, index=False) == 3 if self.flavor == "postgresql": # SQLAlchemy "timezones" (i.e. offsets) are coerced to UTC @@ -1810,7 +1823,7 @@ def test_datetime_with_timezone_roundtrip(self): def test_out_of_bounds_datetime(self): # GH 26761 data = DataFrame({"date": datetime(9999, 1, 1)}, index=[0]) - data.to_sql("test_datetime_obb", self.conn, index=False) + assert data.to_sql("test_datetime_obb", self.conn, index=False) == 1 result = sql.read_sql_table("test_datetime_obb", self.conn) expected = DataFrame([pd.NaT], columns=["date"]) tm.assert_frame_equal(result, expected) @@ -1820,7 +1833,7 @@ def test_naive_datetimeindex_roundtrip(self): # Ensure that a naive DatetimeIndex isn't converted to UTC dates = date_range("2018-01-01", periods=5, freq="6H")._with_freq(None) expected = DataFrame({"nums": range(5)}, index=dates) - expected.to_sql("foo_table", self.conn, index_label="info_date") + assert expected.to_sql("foo_table", self.conn, index_label="info_date") == 5 result = sql.read_sql_table("foo_table", self.conn, index_col="info_date") # result index with gain a name from a set_index operation; expected tm.assert_frame_equal(result, expected, check_names=False) @@ -1861,7 +1874,7 @@ def test_datetime(self): df = DataFrame( {"A": date_range("2013-01-01 09:00:00", periods=3), "B": np.arange(3.0)} ) - df.to_sql("test_datetime", self.conn) + assert df.to_sql("test_datetime", self.conn) == 3 # with read_table -> type information from schema used result = sql.read_sql_table("test_datetime", self.conn) @@ -1883,7 +1896,7 @@ def test_datetime_NaT(self): {"A": date_range("2013-01-01 09:00:00", periods=3), "B": np.arange(3.0)} ) df.loc[1, "A"] = np.nan - df.to_sql("test_datetime", self.conn, index=False) + assert df.to_sql("test_datetime", self.conn, index=False) == 3 # with read_table -> type information from schema used result = sql.read_sql_table("test_datetime", self.conn) @@ -1901,7 +1914,7 @@ def test_datetime_NaT(self): def test_datetime_date(self): # test support for datetime.date df = DataFrame([date(2014, 1, 1), date(2014, 1, 2)], columns=["a"]) - df.to_sql("test_date", self.conn, index=False) + assert df.to_sql("test_date", self.conn, index=False) == 2 res = read_sql_table("test_date", self.conn) result = res["a"] expected = to_datetime(df["a"]) @@ -1911,19 +1924,19 @@ def test_datetime_date(self): def test_datetime_time(self): # test support for datetime.time df = DataFrame([time(9, 0, 0), time(9, 1, 30)], columns=["a"]) - df.to_sql("test_time", self.conn, index=False) + assert df.to_sql("test_time", self.conn, index=False) == 2 res = read_sql_table("test_time", self.conn) tm.assert_frame_equal(res, df) # GH8341 # first, use the fallback to have the sqlite adapter put in place sqlite_conn = TestSQLiteFallback.connect() - sql.to_sql(df, "test_time2", 
sqlite_conn, index=False) + assert sql.to_sql(df, "test_time2", sqlite_conn, index=False) == 2 res = sql.read_sql_query("SELECT * FROM test_time2", sqlite_conn) ref = df.applymap(lambda _: _.strftime("%H:%M:%S.%f")) tm.assert_frame_equal(ref, res) # check if adapter is in place # then test if sqlalchemy is unaffected by the sqlite adapter - sql.to_sql(df, "test_time3", self.conn, index=False) + assert sql.to_sql(df, "test_time3", self.conn, index=False) == 2 if self.flavor == "sqlite": res = sql.read_sql_query("SELECT * FROM test_time3", self.conn) ref = df.applymap(lambda _: _.strftime("%H:%M:%S.%f")) @@ -1938,7 +1951,7 @@ def test_mixed_dtype_insert(self): df = DataFrame({"s1": s1, "s2": s2}) # write and read again - df.to_sql("test_read_write", self.conn, index=False) + assert df.to_sql("test_read_write", self.conn, index=False) == 1 df2 = sql.read_sql_table("test_read_write", self.conn) tm.assert_frame_equal(df, df2, check_dtype=False, check_exact=True) @@ -1946,7 +1959,7 @@ def test_mixed_dtype_insert(self): def test_nan_numeric(self): # NaNs in numeric float column df = DataFrame({"A": [0, 1, 2], "B": [0.2, np.nan, 5.6]}) - df.to_sql("test_nan", self.conn, index=False) + assert df.to_sql("test_nan", self.conn, index=False) == 3 # with read_table result = sql.read_sql_table("test_nan", self.conn) @@ -1959,7 +1972,7 @@ def test_nan_numeric(self): def test_nan_fullcolumn(self): # full NaN column (numeric float column) df = DataFrame({"A": [0, 1, 2], "B": [np.nan, np.nan, np.nan]}) - df.to_sql("test_nan", self.conn, index=False) + assert df.to_sql("test_nan", self.conn, index=False) == 3 # with read_table result = sql.read_sql_table("test_nan", self.conn) @@ -1974,7 +1987,7 @@ def test_nan_fullcolumn(self): def test_nan_string(self): # NaNs in string column df = DataFrame({"A": [0, 1, 2], "B": ["a", "b", np.nan]}) - df.to_sql("test_nan", self.conn, index=False) + assert df.to_sql("test_nan", self.conn, index=False) == 3 # NaNs are coming back as None df.loc[2, "B"] = None @@ -2035,8 +2048,8 @@ def test_dtype(self): cols = ["A", "B"] data = [(0.8, True), (0.9, None)] df = DataFrame(data, columns=cols) - df.to_sql("dtype_test", self.conn) - df.to_sql("dtype_test2", self.conn, dtype={"B": TEXT}) + assert df.to_sql("dtype_test", self.conn) == 2 + assert df.to_sql("dtype_test2", self.conn, dtype={"B": TEXT}) == 2 meta = MetaData() meta.reflect(bind=self.conn) sqltype = meta.tables["dtype_test2"].columns["B"].type @@ -2046,14 +2059,14 @@ def test_dtype(self): df.to_sql("error", self.conn, dtype={"B": str}) # GH9083 - df.to_sql("dtype_test3", self.conn, dtype={"B": String(10)}) + assert df.to_sql("dtype_test3", self.conn, dtype={"B": String(10)}) == 2 meta.reflect(bind=self.conn) sqltype = meta.tables["dtype_test3"].columns["B"].type assert isinstance(sqltype, String) assert sqltype.length == 10 # single dtype - df.to_sql("single_dtype_test", self.conn, dtype=TEXT) + assert df.to_sql("single_dtype_test", self.conn, dtype=TEXT) == 2 meta.reflect(bind=self.conn) sqltypea = meta.tables["single_dtype_test"].columns["A"].type sqltypeb = meta.tables["single_dtype_test"].columns["B"].type @@ -2078,7 +2091,7 @@ def test_notna_dtype(self): df = DataFrame(cols) tbl = "notna_dtype_test" - df.to_sql(tbl, self.conn) + assert df.to_sql(tbl, self.conn) == 2 _ = sql.read_sql_table(tbl, self.conn) meta = MetaData() meta.reflect(bind=self.conn) @@ -2109,12 +2122,15 @@ def test_double_precision(self): } ) - df.to_sql( - "test_dtypes", - self.conn, - index=False, - if_exists="replace", - dtype={"f64_as_f32": 
Float(precision=23)}, + assert ( + df.to_sql( + "test_dtypes", + self.conn, + index=False, + if_exists="replace", + dtype={"f64_as_f32": Float(precision=23)}, + ) + == 1 ) res = sql.read_sql_table("test_dtypes", self.conn) @@ -2161,7 +2177,10 @@ def main(connectable): else: baz(connectable) - DataFrame({"test_foo_data": [0, 1, 2]}).to_sql("test_foo_data", self.conn) + assert ( + DataFrame({"test_foo_data": [0, 1, 2]}).to_sql("test_foo_data", self.conn) + == 3 + ) main(self.conn) @pytest.mark.parametrize( @@ -2188,7 +2207,7 @@ def test_to_sql_with_negative_npinf(self, input, request): with pytest.raises(ValueError, match=msg): df.to_sql("foobar", self.conn, index=False) else: - df.to_sql("foobar", self.conn, index=False) + assert df.to_sql("foobar", self.conn, index=False) == 1 res = sql.read_sql_table("foobar", self.conn) tm.assert_equal(df, res) @@ -2319,7 +2338,7 @@ def test_default_date_load(self): def test_bigint_warning(self): # test no warning for BIGINT (to support int64) is raised (GH7433) df = DataFrame({"a": [1, 2]}, dtype="int64") - df.to_sql("test_bigintwarning", self.conn, index=False) + assert df.to_sql("test_bigintwarning", self.conn, index=False) == 2 with tm.assert_produces_warning(None): sql.read_sql_table("test_bigintwarning", self.conn) @@ -2353,7 +2372,10 @@ class Test(BaseModel): session = Session() df = DataFrame({"id": [0, 1], "foo": ["hello", "world"]}) - df.to_sql("test_frame", con=self.conn, index=False, if_exists="replace") + assert ( + df.to_sql("test_frame", con=self.conn, index=False, if_exists="replace") + == 2 + ) session.commit() foo = session.query(Test.id, Test.foo) @@ -2421,11 +2443,16 @@ def test_schema_support(self): self.conn.execute("CREATE SCHEMA other;") # write dataframe to different schema's - df.to_sql("test_schema_public", self.conn, index=False) - df.to_sql( - "test_schema_public_explicit", self.conn, index=False, schema="public" + assert df.to_sql("test_schema_public", self.conn, index=False) == 2 + assert ( + df.to_sql( + "test_schema_public_explicit", self.conn, index=False, schema="public" + ) + == 2 + ) + assert ( + df.to_sql("test_schema_other", self.conn, index=False, schema="other") == 2 ) - df.to_sql("test_schema_other", self.conn, index=False, schema="other") # read dataframes back in res1 = sql.read_sql_table("test_schema_public", self.conn) @@ -2449,7 +2476,9 @@ def test_schema_support(self): self.conn.execute("CREATE SCHEMA other;") # write dataframe with different if_exists options - df.to_sql("test_schema_other", self.conn, schema="other", index=False) + assert ( + df.to_sql("test_schema_other", self.conn, schema="other", index=False) == 2 + ) df.to_sql( "test_schema_other", self.conn, @@ -2457,12 +2486,15 @@ def test_schema_support(self): index=False, if_exists="replace", ) - df.to_sql( - "test_schema_other", - self.conn, - schema="other", - index=False, - if_exists="append", + assert ( + df.to_sql( + "test_schema_other", + self.conn, + schema="other", + index=False, + if_exists="append", + ) + == 2 ) res = sql.read_sql_table("test_schema_other", self.conn, schema="other") tm.assert_frame_equal(concat([df, df], ignore_index=True), res) @@ -2474,9 +2506,15 @@ def test_schema_support(self): if isinstance(self.conn, Engine): engine2 = self.connect() pdsql = sql.SQLDatabase(engine2, schema="other") - pdsql.to_sql(df, "test_schema_other2", index=False) - pdsql.to_sql(df, "test_schema_other2", index=False, if_exists="replace") - pdsql.to_sql(df, "test_schema_other2", index=False, if_exists="append") + assert pdsql.to_sql(df, 
"test_schema_other2", index=False) == 2 + assert ( + pdsql.to_sql(df, "test_schema_other2", index=False, if_exists="replace") + == 2 + ) + assert ( + pdsql.to_sql(df, "test_schema_other2", index=False, if_exists="append") + == 2 + ) res1 = sql.read_sql_table("test_schema_other2", self.conn, schema="other") res2 = pdsql.read_table("test_schema_other2") tm.assert_frame_equal(res1, res2) @@ -2554,7 +2592,7 @@ def test_create_and_drop_table(self): {"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]} ) - self.pandasSQL.to_sql(temp_frame, "drop_test_frame") + assert self.pandasSQL.to_sql(temp_frame, "drop_test_frame") == 4 assert self.pandasSQL.has_table("drop_test_frame") @@ -2571,7 +2609,7 @@ def test_execute_sql(self): def test_datetime_date(self): # test support for datetime.date df = DataFrame([date(2014, 1, 1), date(2014, 1, 2)], columns=["a"]) - df.to_sql("test_date", self.conn, index=False) + assert df.to_sql("test_date", self.conn, index=False) == 2 res = read_sql_query("SELECT * FROM test_date", self.conn) if self.flavor == "sqlite": # comes back as strings @@ -2582,7 +2620,7 @@ def test_datetime_date(self): def test_datetime_time(self): # test support for datetime.time, GH #8341 df = DataFrame([time(9, 0, 0), time(9, 1, 30)], columns=["a"]) - df.to_sql("test_time", self.conn, index=False) + assert df.to_sql("test_time", self.conn, index=False) == 2 res = read_sql_query("SELECT * FROM test_time", self.conn) if self.flavor == "sqlite": # comes back as strings @@ -2620,8 +2658,8 @@ def test_dtype(self): cols = ["A", "B"] data = [(0.8, True), (0.9, None)] df = DataFrame(data, columns=cols) - df.to_sql("dtype_test", self.conn) - df.to_sql("dtype_test2", self.conn, dtype={"B": "STRING"}) + assert df.to_sql("dtype_test", self.conn) == 2 + assert df.to_sql("dtype_test2", self.conn, dtype={"B": "STRING"}) == 2 # sqlite stores Boolean values as INTEGER assert self._get_sqlite_column_type("dtype_test", "B") == "INTEGER" @@ -2632,7 +2670,7 @@ def test_dtype(self): df.to_sql("error", self.conn, dtype={"B": bool}) # single dtype - df.to_sql("single_dtype_test", self.conn, dtype="STRING") + assert df.to_sql("single_dtype_test", self.conn, dtype="STRING") == 2 assert self._get_sqlite_column_type("single_dtype_test", "A") == "STRING" assert self._get_sqlite_column_type("single_dtype_test", "B") == "STRING" @@ -2649,7 +2687,7 @@ def test_notna_dtype(self): df = DataFrame(cols) tbl = "notna_dtype_test" - df.to_sql(tbl, self.conn) + assert df.to_sql(tbl, self.conn) == 2 assert self._get_sqlite_column_type(tbl, "Bool") == "INTEGER" assert self._get_sqlite_column_type(tbl, "Date") == "TIMESTAMP" @@ -2678,12 +2716,12 @@ def test_illegal_names(self): "\xe9", ] ): - df.to_sql(weird_name, self.conn) + assert df.to_sql(weird_name, self.conn) == 2 sql.table_exists(weird_name, self.conn) df2 = DataFrame([[1, 2], [3, 4]], columns=["a", weird_name]) c_tbl = f"test_weird_col_name{ndx:d}" - df2.to_sql(c_tbl, self.conn) + assert df2.to_sql(c_tbl, self.conn) == 2 sql.table_exists(c_tbl, self.conn) @@ -2736,7 +2774,7 @@ def drop_table(self, table_name): def test_basic(self): frame = tm.makeTimeDataFrame() - sql.to_sql(frame, name="test_table", con=self.conn, index=False) + assert sql.to_sql(frame, name="test_table", con=self.conn, index=False) == 30 result = sql.read_sql("select * from test_table", self.conn) # HACK! Change this once indexes are handled properly. 
@@ -2749,7 +2787,7 @@ def test_basic(self): frame2 = frame.copy() new_idx = Index(np.arange(len(frame2))) + 10 frame2["Idx"] = new_idx.copy() - sql.to_sql(frame2, name="test_table2", con=self.conn, index=False) + assert sql.to_sql(frame2, name="test_table2", con=self.conn, index=False) == 30 result = sql.read_sql("select * from test_table2", self.conn, index_col="Idx") expected = frame.copy() expected.index = new_idx @@ -2845,14 +2883,14 @@ def test_execute_closed_connection(self): def test_keyword_as_column_names(self): df = DataFrame({"From": np.ones(5)}) - sql.to_sql(df, con=self.conn, name="testkeywords", index=False) + assert sql.to_sql(df, con=self.conn, name="testkeywords", index=False) == 5 def test_onecolumn_of_integer(self): # GH 3628 # a column_of_integers dataframe should transfer well to sql mono_df = DataFrame([1, 2], columns=["c0"]) - sql.to_sql(mono_df, con=self.conn, name="mono_df", index=False) + assert sql.to_sql(mono_df, con=self.conn, name="mono_df", index=False) == 2 # computing the sum via sql con_x = self.conn the_sum = sum(my_c0[0] for my_c0 in con_x.execute("select * from mono_df")) @@ -2896,31 +2934,40 @@ def test_if_exists(self): index=False, ) assert tquery(sql_select, con=self.conn) == [(1, "A"), (2, "B")] - sql.to_sql( - frame=df_if_exists_2, - con=self.conn, - name=table_name, - if_exists="replace", - index=False, + assert ( + sql.to_sql( + frame=df_if_exists_2, + con=self.conn, + name=table_name, + if_exists="replace", + index=False, + ) + == 3 ) assert tquery(sql_select, con=self.conn) == [(3, "C"), (4, "D"), (5, "E")] self.drop_table(table_name) # test if_exists='append' - sql.to_sql( - frame=df_if_exists_1, - con=self.conn, - name=table_name, - if_exists="fail", - index=False, + assert ( + sql.to_sql( + frame=df_if_exists_1, + con=self.conn, + name=table_name, + if_exists="fail", + index=False, + ) + == 2 ) assert tquery(sql_select, con=self.conn) == [(1, "A"), (2, "B")] - sql.to_sql( - frame=df_if_exists_2, - con=self.conn, - name=table_name, - if_exists="append", - index=False, + assert ( + sql.to_sql( + frame=df_if_exists_2, + con=self.conn, + name=table_name, + if_exists="append", + index=False, + ) + == 3 ) assert tquery(sql_select, con=self.conn) == [ (1, "A"), From e20db950976974fb43ce7117e12c8f08ba014ed4 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Thu, 30 Dec 2021 20:44:06 -0800 Subject: [PATCH 2/8] Fix multi --- pandas/io/sql.py | 4 ++-- pandas/tests/io/test_sql.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index f70bb03b3ff2f..8df9ad141213c 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -862,8 +862,8 @@ def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int: data = [dict(zip(keys, row)) for row in data_iter] stmt = insert(self.table).values(data) - conn.execute(stmt) - return conn.rowcount + result = conn.execute(stmt) + return result.rowcount def insert_data(self): if self.index is not None: diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 3eba3762d2a34..c1007a01a3829 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -1036,7 +1036,7 @@ def test_timedelta(self): df = to_timedelta(Series(["00:00:01", "00:00:03"], name="foo")).to_frame() with tm.assert_produces_warning(UserWarning): result_count = df.to_sql("test_timedelta", self.conn) - assert result_count is None + assert result_count == 2 result = sql.read_sql_query("SELECT * FROM test_timedelta", self.conn) 
tm.assert_series_equal(result["foo"], df["foo"].view("int64")) From b1ce124ead3cf7f5555bad2726ee6ee6217f812d Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Thu, 30 Dec 2021 21:39:35 -0800 Subject: [PATCH 3/8] Fix doctests --- pandas/core/generic.py | 1 + pandas/io/sql.py | 1 + 2 files changed, 2 insertions(+) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 4235a14c6c63a..cb11ff7bf7df6 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2873,6 +2873,7 @@ def to_sql( 2 User 3 >>> df.to_sql('users', con=engine) + 3 >>> engine.execute("SELECT * FROM users").fetchall() [(0, 'User 1'), (1, 'User 2'), (2, 'User 3')] diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 8df9ad141213c..6b969ff8959dd 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -512,6 +512,7 @@ def read_sql( >>> df = pd.DataFrame(data=[[0, '10/11/12'], [1, '12/11/10']], ... columns=['int_column', 'date_column']) >>> df.to_sql('test_data', conn) + 2 >>> pd.read_sql('SELECT int_column, date_column FROM test_data', conn) int_column date_column From 0b7d4188b2f8ca50cba875168997c568d6849d0a Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Thu, 30 Dec 2021 22:34:05 -0800 Subject: [PATCH 4/8] Fix more doctests --- pandas/core/generic.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index cb11ff7bf7df6..9989b0b55a07f 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2882,12 +2882,13 @@ def to_sql( >>> with engine.begin() as connection: ... df1 = pd.DataFrame({'name' : ['User 4', 'User 5']}) ... df1.to_sql('users', con=connection, if_exists='append') - + 2 This is allowed to support operations that require that the same DBAPI connection is used for the entire operation. >>> df2 = pd.DataFrame({'name' : ['User 6', 'User 7']}) >>> df2.to_sql('users', con=engine, if_exists='append') + 2 >>> engine.execute("SELECT * FROM users").fetchall() [(0, 'User 1'), (1, 'User 2'), (2, 'User 3'), (0, 'User 4'), (1, 'User 5'), (0, 'User 6'), @@ -2897,6 +2898,7 @@ def to_sql( >>> df2.to_sql('users', con=engine, if_exists='replace', ... index_label='id') + 2 >>> engine.execute("SELECT * FROM users").fetchall() [(0, 'User 6'), (1, 'User 7')] @@ -2915,6 +2917,7 @@ def to_sql( >>> from sqlalchemy.types import Integer >>> df.to_sql('integers', con=engine, index=False, ... dtype={"A": Integer()}) + 3 >>> engine.execute("SELECT * FROM integers").fetchall() [(1,), (None,), (2,)] From f25eb6ba6e8231dc082b75d1bd1def504be4ef6d Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Thu, 30 Dec 2021 22:59:01 -0800 Subject: [PATCH 5/8] Space --- pandas/core/generic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 9989b0b55a07f..f5adaeedac86c 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2883,6 +2883,7 @@ def to_sql( ... df1 = pd.DataFrame({'name' : ['User 4', 'User 5']}) ... df1.to_sql('users', con=connection, if_exists='append') 2 + This is allowed to support operations that require that the same DBAPI connection is used for the entire operation. 
From cdd0951d5dd3526701206f8ae004631695bbc0ac Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Fri, 31 Dec 2021 09:22:01 -0800 Subject: [PATCH 6/8] docstring validation --- pandas/core/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index f5adaeedac86c..f3485d2367f53 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2830,7 +2830,7 @@ def to_sql( attribute of ``sqlite3.Cursor`` or SQLAlchemy connectable which may not reflect the exact number of written rows as stipulated in the `sqlite3 `__ or - `SQLAlchemy `__ + `SQLAlchemy `__. .. versionadded:: 1.4.0 From 94d0188867e5918a26ff82d2cedd821d2a17e9e8 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Fri, 31 Dec 2021 10:19:28 -0800 Subject: [PATCH 7/8] Fix typing --- pandas/core/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index f3485d2367f53..16bc6a14c68a4 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2767,7 +2767,7 @@ def to_sql( chunksize=None, dtype: DtypeArg | None = None, method=None, - ) -> None: + ) -> int | None: """ Write records stored in a DataFrame to a SQL database. From 74c48d8daf7f8cb68c49b4b8b4356e1e2d8fdea3 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke Date: Fri, 31 Dec 2021 15:20:01 -0800 Subject: [PATCH 8/8] fix noqa --- pandas/core/generic.py | 2 +- pandas/io/sql.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 16bc6a14c68a4..1e25b0f4eb176 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2922,7 +2922,7 @@ def to_sql( >>> engine.execute("SELECT * FROM integers").fetchall() [(1,), (None,), (2,)] - """ # noqa: 501 + """ # noqa:E501 from pandas.io import sql return sql.to_sql( diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 6b969ff8959dd..022ed2df8598d 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -681,7 +681,7 @@ def to_sql( rows as stipulated in the `sqlite3 `__ or `SQLAlchemy `__ - """ # noqa: 501 + """ # noqa:E501 if if_exists not in ("fail", "replace", "append"): raise ValueError(f"'{if_exists}' is not valid for if_exists")
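---

Editor's note: taken together, the series makes ``DataFrame.to_sql`` return the
summed driver ``rowcount`` instead of ``None``. The sketch below is not part of
the patch; it assumes pandas >= 1.4 with these changes applied and exercises the
new contract through the stdlib ``sqlite3`` fallback, mirroring the values the
series' own tests assert.

    # Sketch of the new return-value contract (assumes pandas >= 1.4).
    # Uses the stdlib sqlite3 fallback, so no SQLAlchemy install is needed.
    import sqlite3

    import pandas as pd

    conn = sqlite3.connect(":memory:")
    df = pd.DataFrame({"a": [1, 2, 3]})

    # Default insert path: to_sql now returns the driver-reported rowcount.
    assert df.to_sql("demo", conn, index=False) == 3

    # An empty frame short-circuits the insert loop and reports 0 rows,
    # matching the patched ``insert()`` (``if nrows == 0: return 0``).
    assert df.iloc[:0].to_sql("demo", conn, index=False, if_exists="append") == 0

As the updated docstrings caution, the result is a sum of DBAPI/SQLAlchemy
``rowcount`` attributes, so a driver that does not populate ``rowcount`` for
bulk inserts may not report the exact number of rows actually written.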
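Continuing the sketch above, a user-supplied ``method`` callable now
participates in the count: an ``int`` return value is accumulated into the
total, while returning ``None`` makes ``to_sql`` return ``None`` overall (see
the ``total_inserted`` loop added to ``SQLTable.insert``). The callable below
is hypothetical — the name ``insert_with_rowcount`` and the assumption that the
sqlite fallback hands the callable a DBAPI cursor as ``conn`` are mine,
inferred from the patched ``SQLiteTable`` code, not stated by the patch.

    # Hypothetical callable matching the newly documented signature
    # ``(pd_table, conn, keys, data_iter) -> int | None``.
    def insert_with_rowcount(pd_table, conn, keys, data_iter):
        placeholders = ", ".join("?" for _ in keys)
        columns = ", ".join(keys)
        stmt = f"INSERT INTO {pd_table.name} ({columns}) VALUES ({placeholders})"
        conn.executemany(stmt, list(data_iter))
        # sqlite3 reports the total rows affected across an executemany call.
        return conn.rowcount

    # Reuses ``df`` and ``conn`` from the previous sketch.
    assert df.to_sql("demo", conn, index=False, if_exists="append",
                     method=insert_with_rowcount) == 3

    # Had the callable returned None instead of ``conn.rowcount``, this
    # to_sql call would itself return None, per the accumulation loop.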