diff --git a/pandas/io/parquet.py b/pandas/io/parquet.py
index b4330349a8c9d..71099a56642f1 100644
--- a/pandas/io/parquet.py
+++ b/pandas/io/parquet.py
@@ -21,6 +21,7 @@
     DataFrame,
     MultiIndex,
     arrays,
+    concat,
     get_option,
 )
 from pandas.core.shared_docs import _shared_docs
@@ -218,8 +219,9 @@ def write(
 
     def read(
         self,
-        path,
+        path: FilePath | ReadBuffer[bytes],
         columns=None,
+        n_rows: int | None = None,
         use_nullable_dtypes: bool = False,
         storage_options: StorageOptions = None,
         **kwargs,
@@ -246,22 +248,40 @@ def read(
             mode="rb",
         )
         try:
-            pa_table = self.api.parquet.read_table(
-                path_or_handle, columns=columns, **kwargs
-            )
-            if dtype_backend == "pandas":
-                result = pa_table.to_pandas(**to_pandas_kwargs)
-            elif dtype_backend == "pyarrow":
-                result = DataFrame(
-                    {
-                        col_name: arrays.ArrowExtensionArray(pa_col)
-                        for col_name, pa_col in zip(
-                            pa_table.column_names, pa_table.itercolumns()
-                        )
-                    }
+            if n_rows is None:
+                pa_table = self.api.parquet.read_table(
+                    path_or_handle, columns=columns, **kwargs
                 )
-            if manager == "array":
-                result = result._as_manager("array", copy=False)
+                if dtype_backend == "pandas":
+                    result = pa_table.to_pandas(**to_pandas_kwargs)
+                elif dtype_backend == "pyarrow":
+                    result = DataFrame(
+                        {
+                            col_name: arrays.ArrowExtensionArray(pa_col)
+                            for col_name, pa_col in zip(
+                                pa_table.column_names, pa_table.itercolumns()
+                            )
+                        }
+                    )
+                if manager == "array":
+                    result = result._as_manager("array", copy=False)
+            else:
+                batch_size = min(65536, n_rows)
+                counter = 0
+                batches = []
+                parquet_file = self.api.parquet.ParquetFile(path_or_handle)
+                for batch in parquet_file.iter_batches(
+                    batch_size=batch_size,
+                    columns=columns,
+                    use_pandas_metadata=kwargs["use_pandas_metadata"],
+                ):
+                    batches.append(batch.to_pandas(**to_pandas_kwargs))
+                    counter += batch.num_rows
+                    if counter >= n_rows:
+                        break
+                # the last batch can overshoot, so trim back to exactly n_rows
+                result = concat(batches).iloc[:n_rows]
+
             return result
         finally:
             if handles is not None:
@@ -325,7 +345,12 @@ def write(
         )
 
     def read(
-        self, path, columns=None, storage_options: StorageOptions = None, **kwargs
+        self,
+        path,
+        columns=None,
+        n_rows: int | None = None,
+        storage_options: StorageOptions = None,
+        **kwargs,
     ) -> DataFrame:
         parquet_kwargs: dict[str, Any] = {}
         use_nullable_dtypes = kwargs.pop("use_nullable_dtypes", False)
@@ -361,6 +386,11 @@ def read(
 
         try:
             parquet_file = self.api.ParquetFile(path, **parquet_kwargs)
+            if n_rows is not None:
+                # only a convenience to mirror the pyarrow impl; the whole
+                # table is still loaded first
+                return parquet_file.head(n_rows, columns=columns, **kwargs)
+
             return parquet_file.to_pandas(columns=columns, **kwargs)
         finally:
             if handles is not None:
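For reviewers, a quick usage sketch. This assumes `pandas.read_parquet` grows a matching `n_rows` keyword that it forwards to the engine `read()` methods patched above; that top-level plumbing is not part of this diff, so treat this as illustrative rather than the final API:

```python
import pandas as pd

# Sketch only: assumes pd.read_parquet() forwards n_rows to the engine
# implementations patched above (plumbing not shown in this diff).
df = pd.read_parquet("data.parquet", engine="pyarrow", n_rows=1_000)
assert len(df) <= 1_000
```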
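The pyarrow branch boils down to: iterate record batches, stop decoding once enough rows are materialized, and trim the overshoot. A minimal standalone version of that technique, using only public pyarrow APIs (`ParquetFile.iter_batches`, `RecordBatch.num_rows`); the function name is hypothetical:

```python
import pandas as pd
import pyarrow.parquet as pq


def read_parquet_head(path: str, n_rows: int, columns=None) -> pd.DataFrame:
    """Read only the first n_rows of a parquet file (standalone sketch)."""
    parquet_file = pq.ParquetFile(path)
    batches = []
    counter = 0
    for batch in parquet_file.iter_batches(
        batch_size=min(65536, n_rows), columns=columns
    ):
        batches.append(batch.to_pandas())
        counter += batch.num_rows
        if counter >= n_rows:
            break  # stop decoding once enough rows are materialized
    if not batches:  # empty file: fall back to an empty frame with the schema
        return parquet_file.read(columns=columns).to_pandas()
    # the final batch can overshoot the target, so trim the concatenated frame
    return pd.concat(batches).iloc[:n_rows]
```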
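One caveat: in the `n_rows` branch above, each batch is converted straight to pandas, so the `dtype_backend == "pyarrow"` and `manager == "array"` handling is bypassed on that path. A possible refactor, sketched here rather than prescribed, is to collect raw batches into a single `pyarrow.Table` so both branches can share the existing conversion code:

```python
import pyarrow as pa
import pyarrow.parquet as pq


def read_head_as_table(path: str, n_rows: int, columns=None) -> pa.Table:
    """Collect just enough record batches, then rebuild a single Table.

    Returning a Table (instead of concatenated DataFrames) would let the
    existing dtype_backend/manager conversion in read() run unchanged.
    Sketch only; not the code path this diff implements.
    """
    parquet_file = pq.ParquetFile(path)
    batches = []
    counter = 0
    for batch in parquet_file.iter_batches(
        batch_size=min(65536, n_rows), columns=columns
    ):
        batches.append(batch)
        counter += batch.num_rows
        if counter >= n_rows:
            break
    if not batches:  # empty file: return an empty table with the right schema
        return parquet_file.read(columns=columns)
    table = pa.Table.from_batches(batches)
    return table.slice(0, n_rows)  # zero-copy trim of any overshoot
```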