diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 5980e3d133374..be870b9fcab1d 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -732,7 +732,6 @@ def to_string( formatter = fmt.DataFrameFormatter( self, - buf=buf, columns=columns, col_space=col_space, na_rep=na_rep, @@ -750,11 +749,7 @@ def to_string( decimal=decimal, line_width=line_width, ) - formatter.to_string() - - if buf is None: - result = formatter.buf.getvalue() - return result + return formatter.to_string(buf=buf) # ---------------------------------------------------------------------- @@ -2273,7 +2268,6 @@ def to_html( formatter = fmt.DataFrameFormatter( self, - buf=buf, columns=columns, col_space=col_space, na_rep=na_rep, @@ -2294,10 +2288,9 @@ def to_html( render_links=render_links, ) # TODO: a generic formatter wld b in DataFrameFormatter - formatter.to_html(classes=classes, notebook=notebook, border=border) - - if buf is None: - return formatter.buf.getvalue() + return formatter.to_html( + buf=buf, classes=classes, notebook=notebook, border=border + ) # ---------------------------------------------------------------------- diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 821c35e0cce2f..1d87a6937ca34 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -3018,7 +3018,6 @@ def to_latex( formatter = DataFrameFormatter( self, - buf=buf, columns=columns, col_space=col_space, na_rep=na_rep, @@ -3032,7 +3031,8 @@ def to_latex( escape=escape, decimal=decimal, ) - formatter.to_latex( + return formatter.to_latex( + buf=buf, column_format=column_format, longtable=longtable, encoding=encoding, @@ -3041,9 +3041,6 @@ def to_latex( multirow=multirow, ) - if buf is None: - return formatter.buf.getvalue() - def to_csv( self, path_or_buf=None, diff --git a/pandas/io/common.py b/pandas/io/common.py index 9a9620e2d0663..e01e473047b88 100644 --- a/pandas/io/common.py +++ b/pandas/io/common.py @@ -10,6 +10,7 @@ import mmap import os import pathlib +from typing 
import IO, AnyStr, BinaryIO, Optional, TextIO, Type from urllib.error import URLError # noqa from urllib.parse import ( # noqa urlencode, @@ -32,6 +33,8 @@ from pandas.core.dtypes.common import is_file_like +from pandas._typing import FilePathOrBuffer + # gh-12665: Alias for now and remove later. CParserError = ParserError @@ -68,14 +71,14 @@ class BaseIterator: Useful only when the object being iterated is non-reusable (e.g. OK for a parser, not for an in-memory table, yes for its iterator).""" - def __iter__(self): + def __iter__(self) -> "BaseIterator": return self def __next__(self): raise AbstractMethodError(self) -def _is_url(url): +def _is_url(url) -> bool: """Check to see if a URL has a valid protocol. Parameters @@ -93,7 +96,9 @@ def _is_url(url): return False -def _expand_user(filepath_or_buffer): +def _expand_user( + filepath_or_buffer: FilePathOrBuffer[AnyStr] +) -> FilePathOrBuffer[AnyStr]: """Return the argument with an initial component of ~ or ~user replaced by that user's home directory. @@ -111,7 +116,7 @@ def _expand_user(filepath_or_buffer): return filepath_or_buffer -def _validate_header_arg(header): +def _validate_header_arg(header) -> None: if isinstance(header, bool): raise TypeError( "Passing a bool to header is invalid. " @@ -121,7 +126,9 @@ def _validate_header_arg(header): ) -def _stringify_path(filepath_or_buffer): +def _stringify_path( + filepath_or_buffer: FilePathOrBuffer[AnyStr] +) -> FilePathOrBuffer[AnyStr]: """Attempt to convert a path-like object to a string. Parameters @@ -144,13 +151,14 @@ def _stringify_path(filepath_or_buffer): strings, buffers, or anything else that's not even path-like. 
""" if hasattr(filepath_or_buffer, "__fspath__"): - return filepath_or_buffer.__fspath__() + # https://github.com/python/mypy/issues/1424 + return filepath_or_buffer.__fspath__() # type: ignore elif isinstance(filepath_or_buffer, pathlib.Path): return str(filepath_or_buffer) return _expand_user(filepath_or_buffer) -def is_s3_url(url): +def is_s3_url(url) -> bool: """Check for an s3, s3n, or s3a url""" try: return parse_url(url).scheme in ["s3", "s3n", "s3a"] @@ -158,7 +166,7 @@ def is_s3_url(url): return False -def is_gcs_url(url): +def is_gcs_url(url) -> bool: """Check for a gcs url""" try: return parse_url(url).scheme in ["gcs", "gs"] @@ -167,7 +175,10 @@ def is_gcs_url(url): def get_filepath_or_buffer( - filepath_or_buffer, encoding=None, compression=None, mode=None + filepath_or_buffer: FilePathOrBuffer, + encoding: Optional[str] = None, + compression: Optional[str] = None, + mode: Optional[str] = None, ): """ If the filepath_or_buffer is a url, translate and return the buffer. @@ -190,7 +201,7 @@ def get_filepath_or_buffer( """ filepath_or_buffer = _stringify_path(filepath_or_buffer) - if _is_url(filepath_or_buffer): + if isinstance(filepath_or_buffer, str) and _is_url(filepath_or_buffer): req = urlopen(filepath_or_buffer) content_encoding = req.headers.get("Content-Encoding", None) if content_encoding == "gzip": @@ -224,7 +235,7 @@ def get_filepath_or_buffer( return filepath_or_buffer, None, compression, False -def file_path_to_url(path): +def file_path_to_url(path: str) -> str: """ converts an absolute native path to a FILE URL. @@ -242,7 +253,9 @@ def file_path_to_url(path): _compression_to_extension = {"gzip": ".gz", "bz2": ".bz2", "zip": ".zip", "xz": ".xz"} -def _infer_compression(filepath_or_buffer, compression): +def _infer_compression( + filepath_or_buffer: FilePathOrBuffer, compression: Optional[str] +) -> Optional[str]: """ Get the compression method for filepath_or_buffer. If compression='infer', the inferred compression method is returned. 
Otherwise, the input @@ -435,7 +448,13 @@ class BytesZipFile(zipfile.ZipFile, BytesIO): # type: ignore """ # GH 17778 - def __init__(self, file, mode, compression=zipfile.ZIP_DEFLATED, **kwargs): + def __init__( + self, + file: FilePathOrBuffer, + mode: str, + compression: int = zipfile.ZIP_DEFLATED, + **kwargs + ): if mode in ["wb", "rb"]: mode = mode.replace("b", "") super().__init__(file, mode, compression, **kwargs) @@ -461,16 +480,16 @@ class MMapWrapper(BaseIterator): """ - def __init__(self, f): + def __init__(self, f: IO): self.mmap = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) - def __getattr__(self, name): + def __getattr__(self, name: str): return getattr(self.mmap, name) - def __iter__(self): + def __iter__(self) -> "MMapWrapper": return self - def __next__(self): + def __next__(self) -> str: newline = self.mmap.readline() # readline returns bytes, not str, but Python's CSV reader @@ -491,16 +510,16 @@ class UTF8Recoder(BaseIterator): Iterator that reads an encoded stream and re-encodes the input to UTF-8 """ - def __init__(self, f, encoding): + def __init__(self, f: BinaryIO, encoding: str): self.reader = codecs.getreader(encoding)(f) - def read(self, bytes=-1): + def read(self, bytes: int = -1) -> bytes: return self.reader.read(bytes).encode("utf-8") - def readline(self): + def readline(self) -> bytes: return self.reader.readline().encode("utf-8") - def next(self): + def next(self) -> bytes: return next(self.reader).encode("utf-8") @@ -511,5 +530,7 @@ def UnicodeReader(f, dialect=csv.excel, encoding="utf-8", **kwds): return csv.reader(f, dialect=dialect, **kwds) -def UnicodeWriter(f, dialect=csv.excel, encoding="utf-8", **kwds): +def UnicodeWriter( + f: TextIO, dialect: Type[csv.Dialect] = csv.excel, encoding: str = "utf-8", **kwds +): return csv.writer(f, dialect=dialect, **kwds) diff --git a/pandas/io/formats/format.py b/pandas/io/formats/format.py index 980fc4888d625..23c07ea72d40f 100644 --- a/pandas/io/formats/format.py +++ 
b/pandas/io/formats/format.py @@ -2,6 +2,9 @@ Internal module for formatting output data in csv, html, and latex files. This module also applies to display formatting. """ + +import codecs +from contextlib import contextmanager import decimal from functools import partial from io import StringIO @@ -9,6 +12,7 @@ import re from shutil import get_terminal_size from typing import ( + IO, TYPE_CHECKING, Any, Callable, @@ -16,7 +20,6 @@ Iterable, List, Optional, - TextIO, Tuple, Type, Union, @@ -34,6 +37,7 @@ from pandas._libs.tslib import format_array_from_datetime from pandas._libs.tslibs import NaT, Timedelta, Timestamp, iNaT from pandas._libs.tslibs.nattype import NaTType +from pandas.errors import AbstractMethodError from pandas.core.dtypes.common import ( is_categorical_dtype, @@ -67,7 +71,7 @@ from pandas.core.indexes.datetimes import DatetimeIndex from pandas.core.indexes.timedeltas import TimedeltaIndex -from pandas.io.common import _expand_user, _stringify_path +from pandas.io.common import _stringify_path from pandas.io.formats.printing import adjoin, justify, pprint_thing if TYPE_CHECKING: @@ -161,7 +165,7 @@ class CategoricalFormatter: def __init__( self, categorical: "Categorical", - buf: Optional[TextIO] = None, + buf: Optional[IO[str]] = None, length: bool = True, na_rep: str = "NaN", footer: bool = True, @@ -224,7 +228,7 @@ class SeriesFormatter: def __init__( self, series: "Series", - buf: Optional[TextIO] = None, + buf: Optional[IO[str]] = None, length: bool = True, header: bool = True, index: bool = True, @@ -463,6 +467,40 @@ def _get_formatter(self, i: Union[str, int]) -> Optional[Callable]: i = self.columns[i] return self.formatters.get(i, None) + @contextmanager + def get_buffer( + self, buf: Optional[FilePathOrBuffer[str]], encoding: Optional[str] = None + ): + if buf is not None: + buf = _stringify_path(buf) + else: + buf = StringIO() + + if encoding is None: + encoding = "utf-8" + + if hasattr(buf, "write"): + yield buf + elif isinstance(buf, 
str): + with codecs.open(buf, "w", encoding=encoding) as f: + yield f + else: + raise TypeError("buf is not a file name and it has no write method") + + def write_result(self, buf: IO[str]) -> None: + raise AbstractMethodError(self) + + def get_result( + self, + buf: Optional[FilePathOrBuffer[str]] = None, + encoding: Optional[str] = None, + ) -> Optional[str]: + with self.get_buffer(buf, encoding=encoding) as f: + self.write_result(buf=f) + if buf is None: + return f.getvalue() + return None + class DataFrameFormatter(TableFormatter): """ @@ -480,7 +518,6 @@ class DataFrameFormatter(TableFormatter): def __init__( self, frame: "DataFrame", - buf: Optional[FilePathOrBuffer] = None, columns: Optional[List[str]] = None, col_space: Optional[Union[str, int]] = None, header: Union[bool, List[str]] = True, @@ -502,10 +539,6 @@ def __init__( **kwds ): self.frame = frame - if buf is not None: - self.buf = _expand_user(_stringify_path(buf)) - else: - self.buf = StringIO() self.show_index_names = index_names if sparsify is None: @@ -727,7 +760,7 @@ def _to_str_columns(self) -> List[List[str]]: strcols[ix].insert(row_num + n_header_rows, dot_str) return strcols - def to_string(self) -> None: + def write_result(self, buf: IO[str]) -> None: """ Render a DataFrame to a console-friendly tabular output. 
""" @@ -782,10 +815,10 @@ def to_string(self) -> None: self._chk_truncate() strcols = self._to_str_columns() text = self.adj.adjoin(1, *strcols) - self.buf.writelines(text) + buf.writelines(text) if self.should_show_dimensions: - self.buf.write( + buf.write( "\n\n[{nrows} rows x {ncols} columns]".format( nrows=len(frame), ncols=len(frame.columns) ) @@ -828,42 +861,33 @@ def _join_multiline(self, *args) -> str: st = ed return "\n\n".join(str_lst) + def to_string(self, buf: Optional[FilePathOrBuffer[str]] = None) -> Optional[str]: + return self.get_result(buf=buf) + def to_latex( self, + buf: Optional[FilePathOrBuffer[str]] = None, column_format: Optional[str] = None, longtable: bool = False, encoding: Optional[str] = None, multicolumn: bool = False, multicolumn_format: Optional[str] = None, multirow: bool = False, - ) -> None: + ) -> Optional[str]: """ Render a DataFrame to a LaTeX tabular/longtable environment output. """ from pandas.io.formats.latex import LatexFormatter - latex_renderer = LatexFormatter( + return LatexFormatter( self, column_format=column_format, longtable=longtable, multicolumn=multicolumn, multicolumn_format=multicolumn_format, multirow=multirow, - ) - - if encoding is None: - encoding = "utf-8" - - if hasattr(self.buf, "write"): - latex_renderer.write_result(self.buf) - elif isinstance(self.buf, str): - import codecs - - with codecs.open(self.buf, "w", encoding=encoding) as f: - latex_renderer.write_result(f) - else: - raise TypeError("buf is not a file name and it has no write method") + ).get_result(buf=buf, encoding=encoding) def _format_col(self, i: int) -> List[str]: frame = self.tr_frame @@ -880,10 +904,11 @@ def _format_col(self, i: int) -> List[str]: def to_html( self, + buf: Optional[FilePathOrBuffer[str]] = None, classes: Optional[Union[str, List, Tuple]] = None, notebook: bool = False, border: Optional[int] = None, - ) -> None: + ) -> Optional[str]: """ Render a DataFrame to a html table. 
@@ -901,14 +926,7 @@ def to_html( from pandas.io.formats.html import HTMLFormatter, NotebookFormatter Klass = NotebookFormatter if notebook else HTMLFormatter - html = Klass(self, classes=classes, border=border).render() - if hasattr(self.buf, "write"): - buffer_put_lines(self.buf, html) - elif isinstance(self.buf, str): - with open(self.buf, "w") as f: - buffer_put_lines(f, html) - else: - raise TypeError("buf is not a file name and it has no write method") + return Klass(self, classes=classes, border=border).get_result(buf=buf) def _get_formatted_column_labels(self, frame: "DataFrame") -> List[List[str]]: from pandas.core.index import _sparsify @@ -1901,7 +1919,7 @@ def get_level_lengths( return result -def buffer_put_lines(buf: TextIO, lines: List[str]) -> None: +def buffer_put_lines(buf: IO[str], lines: List[str]) -> None: """ Appends lines to a buffer. diff --git a/pandas/io/formats/html.py b/pandas/io/formats/html.py index 19305126f4e5f..4b44893df70ed 100644 --- a/pandas/io/formats/html.py +++ b/pandas/io/formats/html.py @@ -4,7 +4,7 @@ from collections import OrderedDict from textwrap import dedent -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import IO, Any, Dict, Iterable, List, Optional, Tuple, Union, cast from pandas._config import get_option @@ -16,6 +16,7 @@ from pandas.io.formats.format import ( DataFrameFormatter, TableFormatter, + buffer_put_lines, get_level_lengths, ) from pandas.io.formats.printing import pprint_thing @@ -203,6 +204,9 @@ def render(self) -> List[str]: return self.elements + def write_result(self, buf: IO[str]) -> None: + buffer_put_lines(buf, self.render()) + def _write_table(self, indent: int = 0) -> None: _classes = ["dataframe"] # Default class. 
use_mathjax = get_option("display.html.use_mathjax") diff --git a/pandas/tests/io/formats/test_format.py b/pandas/tests/io/formats/test_format.py index ad47f714c9550..a048e3bb867bd 100644 --- a/pandas/tests/io/formats/test_format.py +++ b/pandas/tests/io/formats/test_format.py @@ -7,6 +7,7 @@ import itertools from operator import methodcaller import os +from pathlib import Path import re from shutil import get_terminal_size import sys @@ -17,7 +18,7 @@ import pytest import pytz -from pandas.compat import is_platform_32bit, is_platform_windows +from pandas.compat import PY36, is_platform_32bit, is_platform_windows import pandas as pd from pandas import ( @@ -42,6 +43,54 @@ use_32bit_repr = is_platform_windows() or is_platform_32bit() +@pytest.fixture(params=["string", "pathlike", "buffer"]) +def filepath_or_buffer_id(request): + """ + A fixture yielding test ids for filepath_or_buffer testing. + """ + return request.param + + +@pytest.fixture +def filepath_or_buffer(filepath_or_buffer_id, tmp_path): + """ + A fixture yielding a string representing a filepath, a path-like object + and a StringIO buffer. Also checks that buffer is not closed. + """ + if filepath_or_buffer_id == "buffer": + buf = StringIO() + yield buf + assert not buf.closed + else: + if PY36: + assert isinstance(tmp_path, Path) + else: + assert hasattr(tmp_path, "__fspath__") + if filepath_or_buffer_id == "pathlike": + yield tmp_path / "foo" + else: + yield str(tmp_path / "foo") + + +@pytest.fixture +def assert_filepath_or_buffer_equals(filepath_or_buffer, filepath_or_buffer_id): + """ + Assertion helper for checking filepath_or_buffer.
+ """ + + def _assert_filepath_or_buffer_equals(expected): + if filepath_or_buffer_id == "string": + with open(filepath_or_buffer) as f: + result = f.read() + elif filepath_or_buffer_id == "pathlike": + result = filepath_or_buffer.read_text() + elif filepath_or_buffer_id == "buffer": + result = filepath_or_buffer.getvalue() + assert result == expected + + return _assert_filepath_or_buffer_equals + + def curpath(): pth, _ = os.path.split(os.path.abspath(__file__)) return pth @@ -3142,3 +3191,21 @@ def test_repr_html_ipython_config(ip): ) result = ip.run_cell(code) assert not result.error_in_exec + + +@pytest.mark.parametrize("method", ["to_string", "to_html", "to_latex"]) +def test_filepath_or_buffer_arg( + float_frame, method, filepath_or_buffer, assert_filepath_or_buffer_equals +): + df = float_frame + expected = getattr(df, method)() + + getattr(df, method)(buf=filepath_or_buffer) + assert_filepath_or_buffer_equals(expected) + + +@pytest.mark.parametrize("method", ["to_string", "to_html", "to_latex"]) +def test_filepath_or_buffer_bad_arg_raises(float_frame, method): + msg = "buf is not a file name and it has no write method" + with pytest.raises(TypeError, match=msg): + getattr(float_frame, method)(buf=object())