Skip to content

Add AudioDecoder.get_samples_played_in_range() public method #555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/torchcodec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# Note: usort wants to put Frame and FrameBatch after decoders and samplers,
# but that results in circular import.
from ._frame import Frame, FrameBatch # usort:skip # noqa
from ._frame import AudioSamples, Frame, FrameBatch # usort:skip # noqa
from . import decoders, samplers # noqa

try:
Expand Down
27 changes: 26 additions & 1 deletion src/torchcodec/_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def _frame_repr(self):
# Utility to replace Frame and FrameBatch __repr__ method. This prints the
# Utility to replace __repr__ method of dataclasses below. This prints the
# shape of the .data tensor rather than printing the (potentially very long)
# data tensor itself.
s = self.__class__.__name__ + ":\n"
Expand Down Expand Up @@ -114,3 +114,28 @@ def __len__(self):

def __repr__(self):
return _frame_repr(self)


Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Below: the audio decoding API returns the new AudioSamples class rather than a pure Tensor. I think it has the following benefits:

  • users can keep track of the sample_rate within that struct without having to handle it separately
  • for some edge-cases, like with our mp3 test asset, the stream's beginning isn't 0. So we also returns pts_seconds, which may not always be equal to start_seconds.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this makes sense, and it mirrors the video API. In the video API, __getitem__() returns just the tensor, but the named methods return Frame or FrameBatch. I think we should probably do that for audio as well.

@dataclass
class AudioSamples(Iterable):
"""Audio samples with associated metadata."""

# TODO-AUDIO: docs
data: Tensor
pts_seconds: float
sample_rate: int

def __post_init__(self):
# This is called after __init__() when a Frame is created. We can run
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# This is called after __init__() when a Frame is created. We can run
# This is called after __init__() when an AudioSamples instance is created. We can run

# input validation checks here.
if not self.data.ndim == 2:
raise ValueError(f"data must be 2-dimensional, got {self.data.shape = }")
self.pts_seconds = float(self.pts_seconds)
self.sample_rate = int(self.sample_rate)

def __iter__(self) -> Iterator[Union[Tensor, float]]:
for field in dataclasses.fields(self):
yield getattr(self, field.name)

def __repr__(self):
return _frame_repr(self)
68 changes: 68 additions & 0 deletions src/torchcodec/decoders/_audio_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from torch import Tensor

from torchcodec import AudioSamples
from torchcodec.decoders import _core as core
from torchcodec.decoders._decoder_utils import (
create_decoder,
Expand Down Expand Up @@ -37,3 +38,70 @@ def __init__(
) = get_and_validate_stream_metadata(
decoder=self._decoder, stream_index=stream_index, media_type="audio"
)
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy

def get_samples_played_in_range(
self, start_seconds: float, stop_seconds: Optional[float] = None
) -> AudioSamples:
Copy link
Member Author

@NicolasHug NicolasHug Mar 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like start_seconds should default to the stream's beginning, but that can be done later.

Copy link
Contributor

@scotts scotts Mar 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do that, we should also mirror that in the video API. We could try to play the same trick that range() does, where the semantics of the first parameter changes based on the number of parameters, but I'd (softly) rather not do that.

Copy link
Member Author

@NicolasHug NicolasHug Mar 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, I'll update #150 eventually with more of these desirable default behaviors. That would be a good onboarding PR.

"""TODO-AUDIO docs"""
if stop_seconds is not None and not start_seconds <= stop_seconds:
raise ValueError(
f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})."
)
if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds:
raise ValueError(
f"Invalid start seconds: {start_seconds}. "
f"It must be greater than or equal to {self._begin_stream_seconds} "
f"and less than or equal to {self._end_stream_seconds}."
)
frames, first_pts = core.get_frames_by_pts_in_range_audio(
self._decoder,
start_seconds=start_seconds,
stop_seconds=stop_seconds,
)
first_pts = first_pts.item()

# x = frame boundaries
#
# first_pts last_pts
# v v
# ....x..........x..........x...........x..........x..........x.....
# ^ ^
# start_seconds stop_seconds
#
# We want to return the samples in [start_seconds, stop_seconds). But
# because the core API is based on frames, the `frames` tensor contains
# the samples in [first_pts, last_pts)
# So we do some basic math to figure out the position of the view that
# we'll return.

# TODO: sample_rate is either the original one from metadata, or the
# user-specified one (NIY)
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
sample_rate = self.metadata.sample_rate

# TODO: metadata's sample_rate should probably not be Optional
assert sample_rate is not None # mypy.

if first_pts < start_seconds:
offset_beginning = round((start_seconds - first_pts) * sample_rate)
output_pts_seconds = start_seconds
else:
# In normal cases we'll have first_pts <= start_pts, but in some
# edge cases it's possible to have first_pts > start_seconds,
# typically if the stream's first frame's pts isn't exactly 0.
offset_beginning = 0
output_pts_seconds = first_pts

num_samples = frames.shape[1]
last_pts = first_pts + num_samples / self.metadata.sample_rate
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't last_pts be a float here? Is that your intention, or should we call round()?

Copy link
Member Author

@NicolasHug NicolasHug Mar 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it's a float. I didn't add the _seconds everywhere, but they're all in seconds. I can do it if you think it improves clarity.

We do call round() just below, which I think is enough? Or is there an edge-case I'm missing?

if stop_seconds is not None and stop_seconds < last_pts:
offset_end = num_samples - round((last_pts - stop_seconds) * sample_rate)
else:
offset_end = num_samples

return AudioSamples(
data=frames[:, offset_beginning:offset_end],
pts_seconds=output_pts_seconds,
sample_rate=sample_rate,
)
2 changes: 1 addition & 1 deletion src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(

if (startSeconds == stopSeconds) {
// For consistency with video
return AudioFramesOutput{torch::empty({0}), 0.0};
return AudioFramesOutput{torch::empty({0, 0}), 0.0};
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drive-by, the video APIs return something of shape (0, C, H, W) (where C H W are from the metadata). Here, (0, 0) is the best we can do, at least it preserves the number of dimensions.

}

StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class VideoDecoder {
// DECODING AND SEEKING APIs
// --------------------------------------------------------------------------

// All public decoding entry points return either a FrameOutput or a
// All public video decoding entry points return either a FrameOutput or a
// FrameBatchOutput.
// They are the equivalent of the user-facing Frame and FrameBatch classes in
// Python. They contain RGB decoded frames along with some associated data
Expand Down
123 changes: 123 additions & 0 deletions test/decoders/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,3 +955,126 @@ def test_metadata(self, asset):
)
assert decoder.metadata.sample_rate == asset.sample_rate
assert decoder.metadata.num_channels == asset.num_channels

@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
def test_error(self, asset):
decoder = AudioDecoder(asset.path)

with pytest.raises(ValueError, match="Invalid start seconds"):
decoder.get_samples_played_in_range(start_seconds=-1300)

with pytest.raises(ValueError, match="Invalid start seconds"):
decoder.get_samples_played_in_range(start_seconds=9999)

with pytest.raises(ValueError, match="Invalid start seconds"):
decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=2)

@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
@pytest.mark.parametrize("stop_seconds", (None, "duration", 99999999))
def test_get_all_samples(self, asset, stop_seconds):
decoder = AudioDecoder(asset.path)

if stop_seconds == "duration":
stop_seconds = asset.duration_seconds

samples = decoder.get_samples_played_in_range(
start_seconds=0, stop_seconds=stop_seconds
)

reference_frames = asset.get_frame_data_by_range(
start=0, stop=asset.get_frame_index(pts_seconds=asset.duration_seconds) + 1
)

torch.testing.assert_close(samples.data, reference_frames)
assert samples.sample_rate == asset.sample_rate

# TODO there's a bug with NASA_AUDIO_MP3: https://github.com/pytorch/torchcodec/issues/553
expected_pts = (
0.072
if asset is NASA_AUDIO_MP3
else asset.get_frame_info(idx=0).pts_seconds
)
assert samples.pts_seconds == expected_pts

@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
def test_at_frame_boundaries(self, asset):
decoder = AudioDecoder(asset.path)

start_frame_index, stop_frame_index = 10, 40
start_seconds = asset.get_frame_info(start_frame_index).pts_seconds
stop_seconds = asset.get_frame_info(stop_frame_index).pts_seconds

samples = decoder.get_samples_played_in_range(
start_seconds=start_seconds, stop_seconds=stop_seconds
)

reference_frames = asset.get_frame_data_by_range(
start=start_frame_index, stop=stop_frame_index
)

assert samples.pts_seconds == start_seconds
num_samples = samples.data.shape[1]
assert (
num_samples
== reference_frames.shape[1]
== (stop_seconds - start_seconds) * decoder.metadata.sample_rate
)
torch.testing.assert_close(samples.data, reference_frames)
assert samples.sample_rate == asset.sample_rate

@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
def test_not_at_frame_boundaries(self, asset):
decoder = AudioDecoder(asset.path)

start_frame_index, stop_frame_index = 10, 40
start_frame_info = asset.get_frame_info(start_frame_index)
stop_frame_info = asset.get_frame_info(stop_frame_index)
start_seconds = start_frame_info.pts_seconds + (
start_frame_info.duration_seconds / 2
)
stop_seconds = stop_frame_info.pts_seconds + (
stop_frame_info.duration_seconds / 2
)
samples = decoder.get_samples_played_in_range(
start_seconds=start_seconds, stop_seconds=stop_seconds
)

reference_frames = asset.get_frame_data_by_range(
start=start_frame_index, stop=stop_frame_index + 1
)

assert samples.pts_seconds == start_seconds
num_samples = samples.data.shape[1]
assert num_samples < reference_frames.shape[1]
assert (
num_samples == (stop_seconds - start_seconds) * decoder.metadata.sample_rate
)
assert samples.sample_rate == asset.sample_rate

@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
def test_start_equals_stop(self, asset):
decoder = AudioDecoder(asset.path)
samples = decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=3)
assert samples.data.shape == (0, 0)

def test_frame_start_is_not_zero(self):
# For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.072 [1].
# So if we request start = 0.05, we shouldn't be truncating anything.
#
# [1] well, really it's at 0.138125, not 0.072 (see
# https://github.com/pytorch/torchcodec/issues/553), but for the purpose
# of this test it doesn't matter.

asset = NASA_AUDIO_MP3
start_seconds = 0.05 # this is less than the first frame's pts
stop_frame_index = 10
stop_seconds = asset.get_frame_info(stop_frame_index).pts_seconds

decoder = AudioDecoder(asset.path)

samples = decoder.get_samples_played_in_range(
start_seconds=start_seconds, stop_seconds=stop_seconds
)

reference_frames = asset.get_frame_data_by_range(start=0, stop=stop_frame_index)
torch.testing.assert_close(samples.data, reference_frames)
2 changes: 1 addition & 1 deletion test/decoders/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def test_decode_start_equal_stop(self, asset):
frames, pts_seconds = get_frames_by_pts_in_range_audio(
decoder, start_seconds=1, stop_seconds=1
)
assert frames.shape == (0,)
assert frames.shape == (0, 0)
assert pts_seconds == 0

@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
Expand Down
20 changes: 18 additions & 2 deletions test/test_frame_dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pytest
import torch
from torchcodec import Frame, FrameBatch
from torchcodec import AudioSamples, Frame, FrameBatch


def test_frame_unpacking():
def test_unpacking():
data, pts_seconds, duration_seconds = Frame(torch.rand(3, 4, 5), 2, 3) # noqa
data, pts_seconds, sample_rate = AudioSamples(torch.rand(2, 4), 2, 16_000)


def test_frame_error():
Expand Down Expand Up @@ -139,3 +140,18 @@ def test_framebatch_indexing():
fb_fancy = fb[[[0], [1]]] # select T=0 and N=1.
assert isinstance(fb_fancy, FrameBatch)
assert fb_fancy.data.shape == (1, C, H, W)


def test_audio_samples_error():
with pytest.raises(ValueError, match="data must be 2-dimensional"):
AudioSamples(
data=torch.rand(1),
pts_seconds=1,
sample_rate=16_000,
)
with pytest.raises(ValueError, match="data must be 2-dimensional"):
AudioSamples(
data=torch.rand(1, 2, 3),
pts_seconds=1,
sample_rate=16_000,
)
Loading