Skip to content

Commit f73002f

Browse files
authored
ENH: Improve remaining io (#252)
* ENH: Improve to_dict typing
* ENH: Improve to_records typing
* ENH: Improve xarray and dict
* TYP: Restore usual case for to_dict
* TYP: Restore add overloads for to_dict
* TST: Add tests for final io funcs
* TST: Add tests for final io funcs and final fixes
* Use type rather than actual for 3.8
1 parent 260ed76 commit f73002f

File tree

7 files changed: 126 additions (+126) and 37 deletions (-37).

pandas-stubs/_typing.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ CSVEngine = Literal["c", "python", "pyarrow", "python-fwf"]
238238

239239
HDFCompLib = Literal["zlib", "lzo", "bzip2", "blosc"]
240240
ParquetEngine = Literal["auto", "pyarrow", "fastparquet"]
241+
FileWriteMode = Literal[
242+
"a", "w", "x", "at", "wt", "xt", "ab", "wb", "xb", "w+", "w+b", "a+", "a+b"
243+
]
241244
ColspaceArgType = str | int | Sequence[int | str] | Mapping[Hashable, str | int]
242245

243246
__all__ = ["npt", "type_t"]

pandas-stubs/core/frame.pyi

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ from pandas.core.window.rolling import (
3838
Rolling,
3939
Window,
4040
)
41+
import xarray as xr
4142

4243
from pandas._typing import (
4344
S1,
@@ -87,6 +88,7 @@ from pandas._typing import (
8788
XMLParsers,
8889
np_ndarray_bool,
8990
np_ndarray_str,
91+
npt,
9092
num,
9193
)
9294

@@ -228,15 +230,32 @@ class DataFrame(NDFrame, OpsMixin):
228230
@overload
229231
def to_dict(
230232
self,
231-
orient: Literal["records"],
232-
into: Hashable = ...,
233-
) -> list[dict[_str, Any]]: ...
233+
orient: Literal["dict", "list", "series", "split", "tight", "index"],
234+
into: Mapping | type[Mapping],
235+
) -> Mapping[Hashable, Any]: ...
234236
@overload
235237
def to_dict(
236238
self,
237239
orient: Literal["dict", "list", "series", "split", "tight", "index"] = ...,
238-
into: Hashable = ...,
239-
) -> dict[_str, Any]: ...
240+
*,
241+
into: Mapping | type[Mapping],
242+
) -> Mapping[Hashable, Any]: ...
243+
@overload
244+
def to_dict(
245+
self,
246+
orient: Literal["dict", "list", "series", "split", "tight", "index"] = ...,
247+
into: None = ...,
248+
) -> dict[Hashable, Any]: ...
249+
@overload
250+
def to_dict(
251+
self,
252+
orient: Literal["records"],
253+
into: Mapping | type[Mapping],
254+
) -> list[Mapping[Hashable, Any]]: ...
255+
@overload
256+
def to_dict(
257+
self, orient: Literal["records"], into: None = ...
258+
) -> list[dict[Hashable, Any]]: ...
240259
def to_gbq(
241260
self,
242261
destination_table: str,
@@ -258,8 +277,14 @@ class DataFrame(NDFrame, OpsMixin):
258277
def to_records(
259278
self,
260279
index: _bool = ...,
261-
columnDTypes: _str | dict | None = ...,
262-
indexDTypes: _str | dict | None = ...,
280+
column_dtypes: _str
281+
| npt.DTypeLike
282+
| Mapping[HashableT, npt.DTypeLike]
283+
| None = ...,
284+
index_dtypes: _str
285+
| npt.DTypeLike
286+
| Mapping[HashableT, npt.DTypeLike]
287+
| None = ...,
263288
) -> np.recarray: ...
264289
def to_stata(
265290
self,
@@ -279,12 +304,6 @@ class DataFrame(NDFrame, OpsMixin):
279304
) -> None: ...
280305
def to_feather(self, path: FilePath | WriteBuffer[bytes], **kwargs) -> None: ...
281306
@overload
282-
def to_markdown(
283-
self, buf: FilePathOrBuffer | None, mode: _str | None = ..., **kwargs
284-
) -> None: ...
285-
@overload
286-
def to_markdown(self, mode: _str | None = ..., **kwargs) -> _str: ...
287-
@overload
288307
def to_parquet(
289308
self,
290309
path: FilePath | WriteBuffer[bytes],
@@ -2038,7 +2057,7 @@ class DataFrame(NDFrame, OpsMixin):
20382057
max_colwidth: int | None = ...,
20392058
encoding: _str | None = ...,
20402059
) -> _str: ...
2041-
def to_xarray(self): ...
2060+
def to_xarray(self) -> xr.Dataset: ...
20422061
def truediv(
20432062
self,
20442063
other: num | ListLike | DataFrame,

pandas-stubs/core/generic.pyi

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ from pandas._typing import (
2222
Dtype,
2323
FilePath,
2424
FilePathOrBuffer,
25+
FileWriteMode,
2526
FillnaOptions,
2627
FrameOrSeries,
2728
FrameOrSeriesUnion,
@@ -34,6 +35,7 @@ from pandas._typing import (
3435
Scalar,
3536
SeriesAxisType,
3637
SortKind,
38+
StorageOptions,
3739
T,
3840
)
3941

@@ -129,6 +131,24 @@ class NDFrame(PandasObject, indexing.IndexingMixin):
129131
] = ...,
130132
encoding: _str = ...,
131133
) -> None: ...
134+
@overload
135+
def to_markdown(
136+
self,
137+
buf: FilePathOrBuffer,
138+
mode: FileWriteMode | None = ...,
139+
index: _bool = ...,
140+
storage_options: StorageOptions = ...,
141+
**kwargs: Any,
142+
) -> None: ...
143+
@overload
144+
def to_markdown(
145+
self,
146+
buf: None = ...,
147+
mode: FileWriteMode | None = ...,
148+
index: _bool = ...,
149+
storage_options: StorageOptions = ...,
150+
**kwargs: Any,
151+
) -> _str: ...
132152
def to_sql(
133153
self,
134154
name: _str,
@@ -150,7 +170,6 @@ class NDFrame(PandasObject, indexing.IndexingMixin):
150170
def to_clipboard(
151171
self, excel: _bool = ..., sep: _str | None = ..., **kwargs
152172
) -> None: ...
153-
def to_xarray(self): ...
154173
@overload
155174
def to_latex(
156175
self,

pandas-stubs/core/series.pyi

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ from pandas.core.window.rolling import (
5454
Rolling,
5555
Window,
5656
)
57+
import xarray as xr
5758

5859
from pandas._typing import (
5960
S1,
@@ -354,22 +355,6 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
354355
encoding: _str | None = ...,
355356
) -> _str: ...
356357
@overload
357-
def to_markdown(
358-
self,
359-
buf: FilePathOrBuffer | None,
360-
mode: _str | None = ...,
361-
index: _bool = ...,
362-
storage_options: dict | None = ...,
363-
**kwargs,
364-
) -> None: ...
365-
@overload
366-
def to_markdown(
367-
self,
368-
mode: _str | None = ...,
369-
index: _bool = ...,
370-
storage_options: dict | None = ...,
371-
) -> _str: ...
372-
@overload
373358
def to_json(
374359
self,
375360
path_or_buf: FilePathOrBuffer | None,
@@ -400,10 +385,14 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
400385
index: _bool = ...,
401386
indent: int | None = ...,
402387
) -> _str: ...
388+
def to_xarray(self) -> xr.DataArray: ...
403389
def items(self) -> Iterable[tuple[Hashable, S1]]: ...
404390
def iteritems(self) -> Iterable[tuple[Label, S1]]: ...
405391
def keys(self) -> list: ...
406-
def to_dict(self, into: Hashable = ...) -> dict[Any, S1]: ...
392+
@overload
393+
def to_dict(self) -> dict[Hashable, S1]: ...
394+
@overload
395+
def to_dict(self, into: type[Mapping] | Mapping) -> Mapping[Hashable, S1]: ...
407396
def to_frame(self, name: object | None = ...) -> DataFrame: ...
408397
@overload
409398
def groupby(

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ openpyxl = ">=3.0.10"
5151
tables = ">=3.7.0"
5252
lxml = ">=4.7.1,<4.9.0"
5353
pyreadstat = ">=1.1.9"
54+
xarray = ">=22.6.0"
55+
tabulate = ">=0.8.10"
5456

5557
[build-system]
5658
requires = ["poetry-core>=1.0.0"]

tests/test_frame.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
from __future__ import annotations
22

3+
from collections import defaultdict
34
import datetime
45
import io
56
from pathlib import Path
67
from typing import (
78
TYPE_CHECKING,
89
Any,
910
Callable,
11+
Dict,
1012
Generic,
1113
Hashable,
1214
Iterable,
1315
Iterator,
16+
List,
17+
Mapping,
1418
Tuple,
1519
TypeVar,
1620
Union,
@@ -24,13 +28,16 @@
2428
)
2529
import pytest
2630
from typing_extensions import assert_type
31+
import xarray as xr
2732

2833
from pandas._typing import Scalar
2934

3035
from tests import check
3136

3237
from pandas.io.parsers import TextFileReader
3338

39+
DF = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
40+
3441

3542
def test_types_init() -> None:
3643
pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
@@ -777,12 +784,18 @@ def test_types_to_numpy() -> None:
777784

778785

779786
def test_to_markdown() -> None:
780-
pytest.importorskip("tabulate")
781787
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5]})
782-
df.to_markdown()
783-
df.to_markdown(buf=None, mode="wt")
788+
check(assert_type(df.to_markdown(), str), str)
789+
check(assert_type(df.to_markdown(None), str), str)
790+
check(assert_type(df.to_markdown(buf=None, mode="wt"), str), str)
784791
# index param was added in 1.1.0 https://pandas.pydata.org/docs/whatsnew/v1.1.0.html
785-
df.to_markdown(index=False)
792+
check(assert_type(df.to_markdown(index=False), str), str)
793+
with ensure_clean() as path:
794+
check(assert_type(df.to_markdown(path), None), type(None))
795+
with ensure_clean() as path:
796+
check(assert_type(df.to_markdown(Path(path)), None), type(None))
797+
sio = io.StringIO()
798+
check(assert_type(df.to_markdown(sio), None), type(None))
786799

787800

788801
def test_types_to_feather() -> None:
@@ -1687,6 +1700,43 @@ def func() -> MyDataFrame[int]:
16871700
func()
16881701

16891702

1703+
def test_to_xarray():
1704+
check(assert_type(DF.to_xarray(), xr.Dataset), xr.Dataset)
1705+
1706+
1707+
def test_to_records():
1708+
check(assert_type(DF.to_records(False, "int8"), np.recarray), np.recarray)
1709+
check(
1710+
assert_type(DF.to_records(False, index_dtypes=np.int8), np.recarray),
1711+
np.recarray,
1712+
)
1713+
check(
1714+
assert_type(
1715+
DF.to_records(False, {"col1": np.int8, "col2": np.int16}), np.recarray
1716+
),
1717+
np.recarray,
1718+
)
1719+
1720+
1721+
def test_to_dict():
1722+
check(assert_type(DF.to_dict(), Dict[Hashable, Any]), dict)
1723+
check(assert_type(DF.to_dict("split"), Dict[Hashable, Any]), dict)
1724+
1725+
target: Mapping = defaultdict(list)
1726+
check(assert_type(DF.to_dict(into=target), Mapping[Hashable, Any]), defaultdict)
1727+
target = defaultdict(list)
1728+
check(
1729+
assert_type(DF.to_dict("tight", into=target), Mapping[Hashable, Any]),
1730+
defaultdict,
1731+
)
1732+
target = defaultdict(list)
1733+
check(assert_type(DF.to_dict("records"), List[Dict[Hashable, Any]]), list)
1734+
check(
1735+
assert_type(DF.to_dict("records", into=target), List[Mapping[Hashable, Any]]),
1736+
list,
1737+
)
1738+
1739+
16901740
def test_neg() -> None:
16911741
# GH 253
16921742
df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})

tests/test_series.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
TYPE_CHECKING,
88
Any,
99
Dict,
10+
Hashable,
1011
Iterable,
1112
Iterator,
1213
List,
@@ -22,6 +23,7 @@
2223
from pandas.core.window import ExponentialMovingWindow
2324
import pytest
2425
from typing_extensions import assert_type
26+
import xarray as xr
2527

2628
from pandas._typing import Scalar
2729

@@ -909,7 +911,7 @@ def test_types_to_list() -> None:
909911

910912
def test_types_to_dict() -> None:
911913
s = pd.Series(["a", "b", "c"], dtype=str)
912-
assert_type(s.to_dict(), Dict[Any, str])
914+
assert_type(s.to_dict(), Dict[Hashable, str])
913915

914916

915917
def test_categorical_codes():
@@ -1126,6 +1128,11 @@ def test_resample() -> None:
11261128
check(assert_type(df.resample("2T").ohlc(), pd.DataFrame), pd.DataFrame)
11271129

11281130

1131+
def test_to_xarray():
1132+
s = pd.Series([1, 2])
1133+
check(assert_type(s.to_xarray(), xr.DataArray), xr.DataArray)
1134+
1135+
11291136
def test_neg() -> None:
11301137
# GH 253
11311138
sr = pd.Series([1, 2, 3])

0 commit comments

Comments
 (0)