diff --git a/doc/source/whatsnew/v1.0.0.rst b/doc/source/whatsnew/v1.0.0.rst index 6fd758abb1f33..17818118e587d 100755 --- a/doc/source/whatsnew/v1.0.0.rst +++ b/doc/source/whatsnew/v1.0.0.rst @@ -205,6 +205,8 @@ Other enhancements (:meth:`~DataFrame.to_parquet` / :func:`read_parquet`) using the `'pyarrow'` engine now preserve those data types with pyarrow >= 1.0.0 (:issue:`20612`). - The ``partition_cols`` argument in :meth:`DataFrame.to_parquet` now accepts a string (:issue:`27117`) +- :func:`to_parquet` now appropriately handles the ``schema`` argument for user defined schemas in the pyarrow engine. (:issue: `30270`) + Build Changes ^^^^^^^^^^^^^ @@ -798,7 +800,6 @@ I/O - Bug in :func:`read_json` where default encoding was not set to ``utf-8`` (:issue:`29565`) - Bug in :class:`PythonParser` where str and bytes were being mixed when dealing with the decimal field (:issue:`29650`) - :meth:`read_gbq` now accepts ``progress_bar_type`` to display progress bar while the data downloads. (:issue:`29857`) -- Plotting ^^^^^^^^ diff --git a/pandas/io/parquet.py b/pandas/io/parquet.py index 51eac8d481231..54e44ff33d079 100644 --- a/pandas/io/parquet.py +++ b/pandas/io/parquet.py @@ -91,11 +91,10 @@ def write( self.validate_dataframe(df) path, _, _, _ = get_filepath_or_buffer(path, mode="wb") - from_pandas_kwargs: Dict[str, Any] - if index is None: - from_pandas_kwargs = {} - else: - from_pandas_kwargs = {"preserve_index": index} + from_pandas_kwargs: Dict[str, Any] = {"schema": kwargs.pop("schema", None)} + if index is not None: + from_pandas_kwargs["preserve_index"] = index + table = self.api.Table.from_pandas(df, **from_pandas_kwargs) if partition_cols is not None: self.api.parquet.write_to_dataset( diff --git a/pandas/tests/io/test_parquet.py b/pandas/tests/io/test_parquet.py index 251548e7caaab..0b8677d6b1415 100644 --- a/pandas/tests/io/test_parquet.py +++ b/pandas/tests/io/test_parquet.py @@ -517,6 +517,14 @@ def test_empty_dataframe(self, pa): df = pd.DataFrame() check_round_trip(df, pa) + def test_write_with_schema(self, pa): + import pyarrow + + df = pd.DataFrame({"x": [0, 1]}) + schema = pyarrow.schema([pyarrow.field("x", type=pyarrow.bool_())]) + out_df = df.astype(bool) + check_round_trip(df, pa, write_kwargs={"schema": schema}, expected=out_df) + @pytest.mark.skip(reason="broken test") @td.skip_if_no("pyarrow", min_version="0.15.0") def test_additional_extension_arrays(self, pa):