Skip to content

Enable public decoder creation with file like object #616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def create_from_bytes(


def create_from_file_like(
file_like: Union[io.RawIOBase, io.BytesIO], seek_mode: Optional[str] = None
file_like: Union[io.RawIOBase, io.BufferedReader], seek_mode: Optional[str] = None
) -> torch.Tensor:
assert _pybind_ops is not None
return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode))
Expand Down
8 changes: 6 additions & 2 deletions src/torchcodec/decoders/_audio_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import io
from pathlib import Path
from typing import Optional, Union

Expand All @@ -25,11 +26,14 @@ class AudioDecoder:
Returned samples are float samples normalized in [-1, 1]

Args:
source (str, ``Pathlib.path``, ``torch.Tensor``, or bytes): The source of the audio:
source (str, ``Pathlib.path``, bytes, ``torch.Tensor`` or file-like object): The source of the video:

- If ``str``: a local path or a URL to a video or audio file.
- If ``Pathlib.path``: a path to a local video or audio file.
- If ``bytes`` object or ``torch.Tensor``: the raw encoded audio data.
- If file-like object: we read video data from the object on demand. The object must
expose the methods ``read(self, size: int) -> bytes`` and
``seek(self, offset: int, whence: int) -> bytes``. Read more in TODO_FILE_LIKE_TUTORIAL.
stream_index (int, optional): Specifies which stream in the file to decode samples from.
Note that this index is absolute across all media types. If left unspecified, then
the :term:`best stream` is used.
Expand All @@ -45,7 +49,7 @@ class AudioDecoder:

def __init__(
self,
source: Union[str, Path, bytes, Tensor],
source: Union[str, Path, io.RawIOBase, io.BufferedReader, bytes, Tensor],
*,
stream_index: Optional[int] = None,
sample_rate: Optional[int] = None,
Expand Down
21 changes: 19 additions & 2 deletions src/torchcodec/decoders/_decoder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import io
from pathlib import Path

from typing import Union
Expand All @@ -18,18 +19,34 @@


def create_decoder(
*, source: Union[str, Path, bytes, Tensor], seek_mode: str
*,
source: Union[str, Path, io.RawIOBase, io.BufferedReader, bytes, Tensor],
seek_mode: str,
) -> Tensor:
if isinstance(source, str):
return core.create_from_file(source, seek_mode)
elif isinstance(source, Path):
return core.create_from_file(str(source), seek_mode)
elif isinstance(source, io.RawIOBase) or isinstance(source, io.BufferedReader):
return core.create_from_file_like(source, seek_mode)
elif isinstance(source, bytes):
return core.create_from_bytes(source, seek_mode)
elif isinstance(source, Tensor):
return core.create_from_tensor(source, seek_mode)
elif isinstance(source, io.TextIOBase):
raise TypeError(
"source is for reading text, likely from open(..., 'r'). Try with 'rb' for binary reading?"
)
elif hasattr(source, "read") and hasattr(source, "seek"):
# This check must be after checking for text-based reading. Also placing
# it last in general to be defensive: hasattr is a blunt instrument. We
# could use the inspect module to check for methods with the right
# signature.
return core.create_from_file_like(source, seek_mode)

raise TypeError(
f"Unknown source type: {type(source)}. "
"Supported types are str, Path, bytes and Tensor."
"Supported types are str, Path, bytes, Tensor and file-like objects with "
"read(self, size: int) -> bytes and "
"seek(self, offset: int, whence: int) -> bytes methods."
)
8 changes: 6 additions & 2 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import io
import numbers
from pathlib import Path
from typing import Literal, Optional, Tuple, Union
Expand All @@ -21,11 +22,14 @@ class VideoDecoder:
"""A single-stream video decoder.

Args:
source (str, ``Pathlib.path``, ``torch.Tensor``, or bytes): The source of the video.
source (str, ``Pathlib.path``, bytes, ``torch.Tensor`` or file-like object): The source of the video:

- If ``str``: a local path or a URL to a video file.
- If ``Pathlib.path``: a path to a local video file.
- If ``bytes`` object or ``torch.Tensor``: the raw encoded video data.
- If file-like object: we read video data from the object on demand. The object must
expose the methods ``read(self, size: int) -> bytes`` and
``seek(self, offset: int, whence: int) -> bytes``. Read more in TODO_FILE_LIKE_TUTORIAL.
stream_index (int, optional): Specifies which stream in the video to decode frames from.
Note that this index is absolute across all media types. If left unspecified, then
the :term:`best stream` is used.
Expand Down Expand Up @@ -66,7 +70,7 @@ class VideoDecoder:

def __init__(
self,
source: Union[str, Path, bytes, Tensor],
source: Union[str, Path, io.RawIOBase, io.BufferedReader, bytes, Tensor],
*,
stream_index: Optional[int] = None,
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
Expand Down
41 changes: 38 additions & 3 deletions test/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,48 @@ class TestDecoder:
(AudioDecoder, NASA_AUDIO_MP3),
),
)
@pytest.mark.parametrize("source_kind", ("str", "path", "tensor", "bytes"))
@pytest.mark.parametrize(
"source_kind",
(
"str",
"path",
"file_like_rawio",
"file_like_bufferedio",
"file_like_custom",
"bytes",
"tensor",
),
)
def test_create(self, Decoder, asset, source_kind):
if source_kind == "str":
source = str(asset.path)
elif source_kind == "path":
source = asset.path
elif source_kind == "tensor":
source = asset.to_tensor()
elif source_kind == "file_like_rawio":
source = open(asset.path, mode="rb", buffering=0)
elif source_kind == "file_like_bufferedio":
source = open(asset.path, mode="rb", buffering=4096)
elif source_kind == "file_like_custom":
# This class purposefully does not inherit from io.RawIOBase or
# io.BufferedReader. We are testing the case when users pass an
# object that has the right methods but is an arbitrary type.
class CustomReader:
def __init__(self, file):
self._file = file

def read(self, size: int) -> bytes:
return self._file.read(size)

def seek(self, offset: int, whence: int) -> bytes:
return self._file.seek(offset, whence)

source = CustomReader(open(asset.path, mode="rb", buffering=0))
elif source_kind == "bytes":
path = str(asset.path)
with open(path, "rb") as f:
source = f.read()
elif source_kind == "tensor":
source = asset.to_tensor()
else:
raise ValueError("Oops, double check the parametrization of this test!")

Expand All @@ -76,6 +106,11 @@ def test_create_fails(self, Decoder):
with pytest.raises(ValueError, match="No valid stream found"):
Decoder(NASA_VIDEO.path, stream_index=2)

# user mistakenly forgets to specify binary reading when creating a file
# like object from open()
with pytest.raises(TypeError, match="binary reading?"):
Decoder(open(NASA_VIDEO.path, "r"))


class TestVideoDecoder:
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
Expand Down
Loading