Skip to content

Commit 5f22df7

Browse files
Skn0tt and Margarete01 authored
feat: add support for writing tar files
co-authored-by: Margarete Dippel <margarete01@users.noreply.github.com>
1 parent d4e40c9 commit 5f22df7

File tree

10 files changed

+142
-19
lines changed

10 files changed

+142
-19
lines changed

pandas/_testing/_io.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import bz2
44
from functools import wraps
55
import gzip
6+
import io
7+
import tarfile
68
from typing import (
79
TYPE_CHECKING,
810
Any,
@@ -387,6 +389,14 @@ def write_to_compressed(compression, path, data, dest="test"):
387389
mode = "w"
388390
args = (dest, data)
389391
method = "writestr"
392+
elif compression == "tar":
393+
compress_method = tarfile.TarFile
394+
mode = "w"
395+
file = tarfile.TarInfo(name=dest)
396+
bytes = io.BytesIO(data)
397+
file.size = len(data)
398+
args = (file, bytes)
399+
method = "addfile"
390400
elif compression == "gzip":
391401
compress_method = gzip.GzipFile
392402
elif compression == "bz2":

pandas/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,15 +267,15 @@ def other_closed(request):
267267
return request.param
268268

269269

270-
@pytest.fixture(params=[None, "gzip", "bz2", "zip", "xz"])
270+
@pytest.fixture(params=[None, "gzip", "bz2", "zip", "xz", "tar"])
271271
def compression(request):
272272
"""
273273
Fixture for trying common compression types in compression tests.
274274
"""
275275
return request.param
276276

277277

278-
@pytest.fixture(params=["gzip", "bz2", "zip", "xz"])
278+
@pytest.fixture(params=["gzip", "bz2", "zip", "xz", "tar"])
279279
def compression_only(request):
280280
"""
281281
Fixture for trying common compression types in compression tests excluding

pandas/core/generic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2341,6 +2341,7 @@ def to_json(
23412341
default_handler: Callable[[Any], JSONSerializable] | None = None,
23422342
lines: bool_t = False,
23432343
compression: CompressionOptions = "infer",
2344+
mode: str = "w",
23442345
index: bool_t = True,
23452346
indent: int | None = None,
23462347
storage_options: StorageOptions = None,
@@ -2604,6 +2605,7 @@ def to_json(
26042605
default_handler=default_handler,
26052606
lines=lines,
26062607
compression=compression,
2608+
mode=mode,
26072609
index=index,
26082610
indent=indent,
26092611
storage_options=storage_options,
@@ -2923,6 +2925,7 @@ def to_pickle(
29232925
self,
29242926
path,
29252927
compression: CompressionOptions = "infer",
2928+
mode: str = "wb",
29262929
protocol: int = pickle.HIGHEST_PROTOCOL,
29272930
storage_options: StorageOptions = None,
29282931
) -> None:
@@ -2990,6 +2993,7 @@ def to_pickle(
29902993
self,
29912994
path,
29922995
compression=compression,
2996+
mode=mode,
29932997
protocol=protocol,
29942998
storage_options=storage_options,
29952999
)

pandas/io/common.py

Lines changed: 90 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from io import (
1111
BufferedIOBase,
1212
BytesIO,
13+
FileIO,
1314
RawIOBase,
1415
StringIO,
1516
TextIOBase,
@@ -758,18 +759,22 @@ def get_handle(
758759

759760
# TAR Encoding
760761
elif compression == "tar":
761-
tar = tarfile.open(handle, "r:*")
762-
handles.append(tar)
763-
files = tar.getnames()
764-
if len(files) == 1:
765-
handle = tar.extractfile(files[0])
766-
elif len(files) == 0:
767-
raise ValueError(f"Zero files found in TAR archive {path_or_buf}")
762+
if is_path:
763+
handle = _BytesTarFile.open(name=handle, mode=ioargs.mode)
768764
else:
769-
raise ValueError(
770-
"Multiple files found in TAR archive. "
771-
f"Only one file per TAR archive: {files}"
772-
)
765+
handle = _BytesTarFile.open(fileobj=handle, mode=ioargs.mode)
766+
if handle.mode == "r":
767+
handles.append(handle)
768+
files = handle.getnames()
769+
if len(files) == 1:
770+
handle = handle.extractfile(files[0])
771+
elif len(files) == 0:
772+
raise ValueError(f"Zero files found in TAR archive {path_or_buf}")
773+
else:
774+
raise ValueError(
775+
"Multiple files found in TAR archive. "
776+
f"Only one file per TAR archive: {files}"
777+
)
773778

774779
# XZ Compression
775780
elif compression == "xz":
@@ -852,6 +857,80 @@ def get_handle(
852857
)
853858

854859

860+
class _BytesTarFile(tarfile.TarFile, BytesIO):
    """
    Tar archive handle that offers a file-like ``write`` interface.

    ``TarFile.addfile`` needs the size of a member before its data is
    written, so everything passed to :meth:`write` is collected in an
    in-memory buffer and stored as one single member on the first
    :meth:`flush` (or on :meth:`close`).

    Parameters
    ----------
    name : str, path object or None
        Location of the archive; also used to derive the member name when
        ``archive_name`` is not given.
    mode : str
        Plain ``TarFile`` mode, i.e. without a compression suffix.
    fileobj : file-like object or None
        Binary handle the archive is read from / written to.
    archive_name : str, optional
        Name of the member created inside the archive.
    **kwargs
        Forwarded to ``tarfile.TarFile``.
    """

    # GH 17778
    def __init__(
        self,
        name: FilePath | ReadBuffer[bytes] | WriteBuffer[bytes],
        mode: str,
        fileobj: FileIO,
        archive_name: str | None = None,
        **kwargs,
    ):
        self.archive_name = archive_name
        self.multiple_write_buffer: StringIO | BytesIO | None = None
        # TarFile.__init__ assigns ``self.closed`` once it starts setting up
        # the archive. Until then the handle counts as closed, so an instance
        # whose construction failed early is not finalized later on.
        self._closed = True

        super().__init__(name=name, mode=mode, fileobj=fileobj, **kwargs)

    @classmethod
    def open(cls, mode="r", **kwargs):
        """
        Open an archive, accepting file modes that carry a binary flag.

        Only the part in front of the compression separator is touched, so
        "wb" becomes "w" while the "b" of "w:bz2" is left alone.
        """
        cut = len(mode)
        for separator in (":", "|"):
            position = mode.find(separator)
            if position != -1:
                cut = min(cut, position)
        mode = mode[:cut].replace("b", "") + mode[cut:]
        return super().open(mode=mode, **kwargs)

    def infer_filename(self):
        """
        If an explicit archive_name is not given, we still want the file inside
        the tar archive not to be named something.tar, because that causes
        confusion (GH39465).

        Returns
        -------
        str or None
            File name of the archive without its ".tar" / ".tar.<compression>"
            extension, or None when the archive has no path-like name.
        """
        if not isinstance(self.name, (os.PathLike, str)):
            return None

        filename = Path(self.name)
        # Path.suffix only holds the last extension, so the two-part
        # extensions have to be matched against the whole file name.
        if filename.name.endswith((".tar.gz", ".tar.bz2", ".tar.xz")):
            return filename.with_suffix("").with_suffix("").name
        if filename.suffix == ".tar":
            return filename.with_suffix("").name
        return filename.name

    def write(self, data):
        """
        Buffer ``data``; it reaches the archive on flush.

        The type of the first chunk decides whether text or bytes are
        collected. Returns whatever the underlying buffer's ``write`` returns.
        """
        # buffer multiple write calls, write on flush
        if self.multiple_write_buffer is None:
            self.multiple_write_buffer = (
                StringIO() if isinstance(data, str) else BytesIO()
            )
        return self.multiple_write_buffer.write(data)

    def flush(self) -> None:
        """
        Store the buffered data as a single archive member.

        The buffer is closed afterwards, so only the first flush that sees
        data adds a member; later calls are no-ops.
        """
        # write to actual handle and close write buffer
        pending = self.multiple_write_buffer
        if pending is None or pending.closed:
            return

        # TarFile needs a non-empty string
        archive_name = self.archive_name or self.infer_filename() or "tar"
        with pending:
            value = pending.getvalue()
        if isinstance(value, str):
            # members are stored as bytes and sized in bytes, not characters
            value = value.encode("utf-8")
        tarinfo = tarfile.TarInfo(name=archive_name)
        tarinfo.size = len(value)
        self.addfile(tarinfo, BytesIO(value))

    def close(self):
        """
        Flush pending data and finalize the archive; repeated calls are no-ops.
        """
        if self.closed:
            return
        try:
            self.flush()
        finally:
            # TarFile.close marks the handle as closed, writes the end-of-archive
            # blocks and releases a file object it opened itself.
            super().close()

    @property
    def closed(self):
        """Whether the archive has been closed."""
        return self._closed

    @closed.setter
    def closed(self, value):
        # TarFile keeps its state in a plain ``closed`` attribute, while BytesIO
        # exposes ``closed`` as a read-only descriptor. The setter lets TarFile's
        # own assignments succeed; it only records the state, because closing
        # from in here would re-enter TarFile.close and finalize twice.
        self._closed = bool(value)
932+
933+
855934
# error: Definition of "__exit__" in base class "ZipFile" is incompatible with
856935
# definition in base class "BytesIO" [misc]
857936
# error: Definition of "__enter__" in base class "ZipFile" is incompatible with

pandas/io/json/_json.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def to_json(
8383
default_handler: Callable[[Any], JSONSerializable] | None = None,
8484
lines: bool = False,
8585
compression: CompressionOptions = "infer",
86+
mode: str = "w",
8687
index: bool = True,
8788
indent: int = 0,
8889
storage_options: StorageOptions = None,
@@ -127,7 +128,7 @@ def to_json(
127128
if path_or_buf is not None:
128129
# apply compression and byte/text conversion
129130
with get_handle(
130-
path_or_buf, "w", compression=compression, storage_options=storage_options
131+
path_or_buf, mode, compression=compression, storage_options=storage_options
131132
) as handles:
132133
handles.handle.write(s)
133134
else:

pandas/io/pickle.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def to_pickle(
2525
obj: Any,
2626
filepath_or_buffer: FilePath | WriteBuffer[bytes],
2727
compression: CompressionOptions = "infer",
28+
mode: str = "wb",
2829
protocol: int = pickle.HIGHEST_PROTOCOL,
2930
storage_options: StorageOptions = None,
3031
):
@@ -95,7 +96,7 @@ def to_pickle(
9596

9697
with get_handle(
9798
filepath_or_buffer,
98-
"wb",
99+
mode,
99100
compression=compression,
100101
is_text=False,
101102
storage_options=storage_options,

pandas/tests/io/test_compression.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
import pandas.io.common as icom
1515

1616

17+
def flip(my_dict: dict):
    """Return a new dict mapping every value of ``my_dict`` back to its key."""
    # pair each value with its key; for duplicated values the last key wins
    return dict(zip(my_dict.values(), my_dict.keys()))
19+
20+
1721
@pytest.mark.parametrize(
1822
"obj",
1923
[
@@ -26,8 +30,13 @@
2630
)
2731
@pytest.mark.parametrize("method", ["to_pickle", "to_json", "to_csv"])
2832
def test_compression_size(obj, method, compression_only):
33+
kwargs = {}
34+
35+
if compression_only == "tar":
36+
kwargs["mode"] = "w:gz"
37+
2938
with tm.ensure_clean() as path:
30-
getattr(obj, method)(path, compression=compression_only)
39+
getattr(obj, method)(path, compression=compression_only, **kwargs)
3140
compressed_size = os.path.getsize(path)
3241
getattr(obj, method)(path, compression=None)
3342
uncompressed_size = os.path.getsize(path)
@@ -72,7 +81,7 @@ def test_dataframe_compression_defaults_to_infer(
7281
):
7382
# GH22004
7483
input = pd.DataFrame([[1.0, 0, -4], [3.4, 5, 2]], columns=["X", "Y", "Z"])
75-
extension = icom._compression_to_extension[compression_only]
84+
extension = flip(icom._extension_to_compression)[compression_only]
7685
with tm.ensure_clean("compressed" + extension) as path:
7786
getattr(input, write_method)(path, **write_kwargs)
7887
output = read_method(path, compression=compression_only)
@@ -92,7 +101,7 @@ def test_series_compression_defaults_to_infer(
92101
):
93102
# GH22004
94103
input = pd.Series([0, 5, -2, 10], name="X")
95-
extension = icom._compression_to_extension[compression_only]
104+
extension = flip(icom._extension_to_compression)[compression_only]
96105
with tm.ensure_clean("compressed" + extension) as path:
97106
getattr(input, write_method)(path, **write_kwargs)
98107
if "squeeze" in read_kwargs:

pandas/tests/io/test_gcs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from io import BytesIO
22
import os
3+
import tarfile
34
import zipfile
45

56
import numpy as np
@@ -104,6 +105,14 @@ def assert_equal_zip_safe(result: bytes, expected: bytes, compression: str):
104105
) as res:
105106
for res_info, exp_info in zip(res.infolist(), exp.infolist()):
106107
assert res_info.CRC == exp_info.CRC
108+
elif compression == "tar":
109+
with tarfile.open(fileobj=BytesIO(result)) as exp, tarfile.open(
110+
fileobj=BytesIO(expected)
111+
) as res:
112+
for res_info, exp_info in zip(res.getmembers(), exp.getmembers()):
113+
assert (
114+
res.extractfile(res_info).read() == exp.extractfile(exp_info).read()
115+
)
107116
else:
108117
assert result == expected
109118

pandas/tests/io/test_pickle.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pathlib import Path
2222
import pickle
2323
import shutil
24+
import tarfile
2425
from warnings import (
2526
catch_warnings,
2627
filterwarnings,
@@ -306,13 +307,18 @@ def compress_file(self, src_path, dest_path, compression):
306307
elif compression == "zip":
307308
with zipfile.ZipFile(dest_path, "w", compression=zipfile.ZIP_DEFLATED) as f:
308309
f.write(src_path, os.path.basename(src_path))
310+
elif compression == "tar":
311+
with open(src_path, "rb") as fh:
312+
with tarfile.open(dest_path, mode="w") as tar:
313+
tarinfo = tar.gettarinfo(src_path, os.path.basename(src_path))
314+
tar.addfile(tarinfo, fh)
309315
elif compression == "xz":
310316
f = get_lzma_file()(dest_path, "w")
311317
else:
312318
msg = f"Unrecognized compression type: {compression}"
313319
raise ValueError(msg)
314320

315-
if compression != "zip":
321+
if compression not in ["zip", "tar"]:
316322
with open(src_path, "rb") as fh, f:
317323
f.write(fh.read())
318324

pandas/tests/io/test_stata.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import io
66
import os
77
import struct
8+
import tarfile
89
import warnings
910
import zipfile
1011

@@ -1899,6 +1900,9 @@ def test_compression(compression, version, use_dict, infer):
18991900
elif compression == "zip":
19001901
with zipfile.ZipFile(path, "r") as comp:
19011902
fp = io.BytesIO(comp.read(comp.filelist[0]))
1903+
elif compression == "tar":
1904+
with tarfile.open(path) as tar:
1905+
fp = io.BytesIO(tar.extractfile(tar.getnames()[0]).read())
19021906
elif compression == "bz2":
19031907
with bz2.open(path, "rb") as comp:
19041908
fp = io.BytesIO(comp.read())

0 commit comments

Comments
 (0)