Skip to content

Commit e611e29

Browse files
authored
Enable public decoder creation with file like object (#616)
1 parent bea7360 commit e611e29

File tree

5 files changed

+70
-10
lines changed

5 files changed

+70
-10
lines changed

src/torchcodec/_core/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def create_from_bytes(
147147

148148

149149
def create_from_file_like(
150-
file_like: Union[io.RawIOBase, io.BytesIO], seek_mode: Optional[str] = None
150+
file_like: Union[io.RawIOBase, io.BufferedReader], seek_mode: Optional[str] = None
151151
) -> torch.Tensor:
152152
assert _pybind_ops is not None
153153
return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode))

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import io
78
from pathlib import Path
89
from typing import Optional, Union
910

@@ -25,11 +26,14 @@ class AudioDecoder:
2526
Returned samples are float samples normalized in [-1, 1]
2627
2728
Args:
28-
source (str, ``Pathlib.path``, ``torch.Tensor``, or bytes): The source of the audio:
29+
source (str, ``pathlib.Path``, bytes, ``torch.Tensor`` or file-like object): The source of the audio:
2930
3031
- If ``str``: a local path or a URL to a video or audio file.
3132
- If ``pathlib.Path``: a path to a local video or audio file.
3233
- If ``bytes`` object or ``torch.Tensor``: the raw encoded audio data.
34+
- If file-like object: we read audio data from the object on demand. The object must
35+
expose the methods ``read(self, size: int) -> bytes`` and
36+
``seek(self, offset: int, whence: int) -> int``. Read more in TODO_FILE_LIKE_TUTORIAL.
3337
stream_index (int, optional): Specifies which stream in the file to decode samples from.
3438
Note that this index is absolute across all media types. If left unspecified, then
3539
the :term:`best stream` is used.
@@ -45,7 +49,7 @@ class AudioDecoder:
4549

4650
def __init__(
4751
self,
48-
source: Union[str, Path, bytes, Tensor],
52+
source: Union[str, Path, io.RawIOBase, io.BufferedReader, bytes, Tensor],
4953
*,
5054
stream_index: Optional[int] = None,
5155
sample_rate: Optional[int] = None,

src/torchcodec/decoders/_decoder_utils.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import io
78
from pathlib import Path
89

910
from typing import Union
@@ -18,18 +19,34 @@
1819

1920

2021
def create_decoder(
21-
*, source: Union[str, Path, bytes, Tensor], seek_mode: str
22+
*,
23+
source: Union[str, Path, io.RawIOBase, io.BufferedReader, bytes, Tensor],
24+
seek_mode: str,
2225
) -> Tensor:
2326
if isinstance(source, str):
2427
return core.create_from_file(source, seek_mode)
2528
elif isinstance(source, Path):
2629
return core.create_from_file(str(source), seek_mode)
30+
elif isinstance(source, io.RawIOBase) or isinstance(source, io.BufferedReader):
31+
return core.create_from_file_like(source, seek_mode)
2732
elif isinstance(source, bytes):
2833
return core.create_from_bytes(source, seek_mode)
2934
elif isinstance(source, Tensor):
3035
return core.create_from_tensor(source, seek_mode)
36+
elif isinstance(source, io.TextIOBase):
37+
raise TypeError(
38+
"source is for reading text, likely from open(..., 'r'). Try with 'rb' for binary reading?"
39+
)
40+
elif hasattr(source, "read") and hasattr(source, "seek"):
41+
# This check must be after checking for text-based reading. Also placing
42+
# it last in general to be defensive: hasattr is a blunt instrument. We
43+
# could use the inspect module to check for methods with the right
44+
# signature.
45+
return core.create_from_file_like(source, seek_mode)
3146

3247
raise TypeError(
3348
f"Unknown source type: {type(source)}. "
34-
"Supported types are str, Path, bytes and Tensor."
49+
"Supported types are str, Path, bytes, Tensor and file-like objects with "
50+
"read(self, size: int) -> bytes and "
51+
"seek(self, offset: int, whence: int) -> bytes methods."
3552
)

src/torchcodec/decoders/_video_decoder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import io
78
import numbers
89
from pathlib import Path
910
from typing import Literal, Optional, Tuple, Union
@@ -21,11 +22,14 @@ class VideoDecoder:
2122
"""A single-stream video decoder.
2223
2324
Args:
24-
source (str, ``Pathlib.path``, ``torch.Tensor``, or bytes): The source of the video.
25+
source (str, ``pathlib.Path``, bytes, ``torch.Tensor`` or file-like object): The source of the video:
2526
2627
- If ``str``: a local path or a URL to a video file.
2728
- If ``pathlib.Path``: a path to a local video file.
2829
- If ``bytes`` object or ``torch.Tensor``: the raw encoded video data.
30+
- If file-like object: we read video data from the object on demand. The object must
31+
expose the methods ``read(self, size: int) -> bytes`` and
32+
``seek(self, offset: int, whence: int) -> int``. Read more in TODO_FILE_LIKE_TUTORIAL.
2933
stream_index (int, optional): Specifies which stream in the video to decode frames from.
3034
Note that this index is absolute across all media types. If left unspecified, then
3135
the :term:`best stream` is used.
@@ -66,7 +70,7 @@ class VideoDecoder:
6670

6771
def __init__(
6872
self,
69-
source: Union[str, Path, bytes, Tensor],
73+
source: Union[str, Path, io.RawIOBase, io.BufferedReader, bytes, Tensor],
7074
*,
7175
stream_index: Optional[int] = None,
7276
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",

test/test_decoders.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,48 @@ class TestDecoder:
4545
(AudioDecoder, NASA_AUDIO_MP3),
4646
),
4747
)
48-
@pytest.mark.parametrize("source_kind", ("str", "path", "tensor", "bytes"))
48+
@pytest.mark.parametrize(
49+
"source_kind",
50+
(
51+
"str",
52+
"path",
53+
"file_like_rawio",
54+
"file_like_bufferedio",
55+
"file_like_custom",
56+
"bytes",
57+
"tensor",
58+
),
59+
)
4960
def test_create(self, Decoder, asset, source_kind):
5061
if source_kind == "str":
5162
source = str(asset.path)
5263
elif source_kind == "path":
5364
source = asset.path
54-
elif source_kind == "tensor":
55-
source = asset.to_tensor()
65+
elif source_kind == "file_like_rawio":
66+
source = open(asset.path, mode="rb", buffering=0)
67+
elif source_kind == "file_like_bufferedio":
68+
source = open(asset.path, mode="rb", buffering=4096)
69+
elif source_kind == "file_like_custom":
70+
# This class purposefully does not inherit from io.RawIOBase or
71+
# io.BufferedReader. We are testing the case when users pass an
72+
# object that has the right methods but is an arbitrary type.
73+
class CustomReader:
74+
def __init__(self, file):
75+
self._file = file
76+
77+
def read(self, size: int) -> bytes:
78+
return self._file.read(size)
79+
80+
def seek(self, offset: int, whence: int) -> bytes:
81+
return self._file.seek(offset, whence)
82+
83+
source = CustomReader(open(asset.path, mode="rb", buffering=0))
5684
elif source_kind == "bytes":
5785
path = str(asset.path)
5886
with open(path, "rb") as f:
5987
source = f.read()
88+
elif source_kind == "tensor":
89+
source = asset.to_tensor()
6090
else:
6191
raise ValueError("Oops, double check the parametrization of this test!")
6292

@@ -76,6 +106,11 @@ def test_create_fails(self, Decoder):
76106
with pytest.raises(ValueError, match="No valid stream found"):
77107
Decoder(NASA_VIDEO.path, stream_index=2)
78108

109+
# user mistakenly forgets to specify binary reading when creating a file
110+
# like object from open()
111+
with pytest.raises(TypeError, match="binary reading?"):
112+
Decoder(open(NASA_VIDEO.path, "r"))
113+
79114

80115
class TestVideoDecoder:
81116
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))

0 commit comments

Comments
 (0)