diff --git a/src/torchcodec/__init__.py b/src/torchcodec/__init__.py index e96aa5b3..2b6c9fb3 100644 --- a/src/torchcodec/__init__.py +++ b/src/torchcodec/__init__.py @@ -6,7 +6,7 @@ # Note: usort wants to put Frame and FrameBatch after decoders and samplers, # but that results in circular import. -from ._frame import Frame, FrameBatch # usort:skip # noqa +from ._frame import AudioSamples, Frame, FrameBatch # usort:skip # noqa from . import decoders, samplers # noqa try: diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index a8fc7f5b..31a9b666 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -12,7 +12,7 @@ def _frame_repr(self): - # Utility to replace Frame and FrameBatch __repr__ method. This prints the + # Utility to replace __repr__ method of dataclasses below. This prints the # shape of the .data tensor rather than printing the (potentially very long) # data tensor itself. s = self.__class__.__name__ + ":\n" @@ -114,3 +114,28 @@ def __len__(self): def __repr__(self): return _frame_repr(self) + + +@dataclass +class AudioSamples(Iterable): + """Audio samples with associated metadata.""" + + # TODO-AUDIO: docs + data: Tensor + pts_seconds: float + sample_rate: int + + def __post_init__(self): + # This is called after __init__() when a Frame is created. We can run + # input validation checks here. + if not self.data.ndim == 2: + raise ValueError(f"data must be 2-dimensional, got {self.data.shape = }") + self.pts_seconds = float(self.pts_seconds) + self.sample_rate = int(self.sample_rate) + + def __iter__(self) -> Iterator[Union[Tensor, float]]: + for field in dataclasses.fields(self): + yield getattr(self, field.name) + + def __repr__(self): + return _frame_repr(self) diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index a99d5d3f..34292751 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -9,6 +9,7 @@ from torch import Tensor +from torchcodec import AudioSamples from torchcodec.decoders import _core as core from torchcodec.decoders._decoder_utils import ( create_decoder, @@ -37,3 +38,70 @@ def __init__( ) = get_and_validate_stream_metadata( decoder=self._decoder, stream_index=stream_index, media_type="audio" ) + assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy + + def get_samples_played_in_range( + self, start_seconds: float, stop_seconds: Optional[float] = None + ) -> AudioSamples: + """TODO-AUDIO docs""" + if stop_seconds is not None and not start_seconds <= stop_seconds: + raise ValueError( + f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})." + ) + if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds: + raise ValueError( + f"Invalid start seconds: {start_seconds}. " + f"It must be greater than or equal to {self._begin_stream_seconds} " + f"and less than or equal to {self._end_stream_seconds}." + ) + frames, first_pts = core.get_frames_by_pts_in_range_audio( + self._decoder, + start_seconds=start_seconds, + stop_seconds=stop_seconds, + ) + first_pts = first_pts.item() + + # x = frame boundaries + # + # first_pts last_pts + # v v + # ....x..........x..........x...........x..........x..........x..... + # ^ ^ + # start_seconds stop_seconds + # + # We want to return the samples in [start_seconds, stop_seconds). But + # because the core API is based on frames, the `frames` tensor contains + # the samples in [first_pts, last_pts) + # So we do some basic math to figure out the position of the view that + # we'll return. + + # TODO: sample_rate is either the original one from metadata, or the + # user-specified one (NIY) + assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy + sample_rate = self.metadata.sample_rate + + # TODO: metadata's sample_rate should probably not be Optional + assert sample_rate is not None # mypy. + + if first_pts < start_seconds: + offset_beginning = round((start_seconds - first_pts) * sample_rate) + output_pts_seconds = start_seconds + else: + # In normal cases we'll have first_pts <= start_pts, but in some + # edge cases it's possible to have first_pts > start_seconds, + # typically if the stream's first frame's pts isn't exactly 0. + offset_beginning = 0 + output_pts_seconds = first_pts + + num_samples = frames.shape[1] + last_pts = first_pts + num_samples / self.metadata.sample_rate + if stop_seconds is not None and stop_seconds < last_pts: + offset_end = num_samples - round((last_pts - stop_seconds) * sample_rate) + else: + offset_end = num_samples + + return AudioSamples( + data=frames[:, offset_beginning:offset_end], + pts_seconds=output_pts_seconds, + sample_rate=sample_rate, + ) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 0e287a5b..7c75b8fb 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -854,7 +854,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio( if (startSeconds == stopSeconds) { // For consistency with video - return AudioFramesOutput{torch::empty({0}), 0.0}; + return AudioFramesOutput{torch::empty({0, 0}), 0.0}; } StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 60c10055..3ca6380b 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -147,7 +147,7 @@ class VideoDecoder { // DECODING AND SEEKING APIs // -------------------------------------------------------------------------- - // All public decoding entry points return either a FrameOutput or a + // All public video decoding entry points return either a FrameOutput or a // FrameBatchOutput. // They are the equivalent of the user-facing Frame and FrameBatch classes in // Python. They contain RGB decoded frames along with some associated data diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index e0339120..6cd01dad 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -955,3 +955,126 @@ def test_metadata(self, asset): ) assert decoder.metadata.sample_rate == asset.sample_rate assert decoder.metadata.num_channels == asset.num_channels + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_error(self, asset): + decoder = AudioDecoder(asset.path) + + with pytest.raises(ValueError, match="Invalid start seconds"): + decoder.get_samples_played_in_range(start_seconds=-1300) + + with pytest.raises(ValueError, match="Invalid start seconds"): + decoder.get_samples_played_in_range(start_seconds=9999) + + with pytest.raises(ValueError, match="Invalid start seconds"): + decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=2) + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + @pytest.mark.parametrize("stop_seconds", (None, "duration", 99999999)) + def test_get_all_samples(self, asset, stop_seconds): + decoder = AudioDecoder(asset.path) + + if stop_seconds == "duration": + stop_seconds = asset.duration_seconds + + samples = decoder.get_samples_played_in_range( + start_seconds=0, stop_seconds=stop_seconds + ) + + reference_frames = asset.get_frame_data_by_range( + start=0, stop=asset.get_frame_index(pts_seconds=asset.duration_seconds) + 1 + ) + + torch.testing.assert_close(samples.data, reference_frames) + assert samples.sample_rate == asset.sample_rate + + # TODO there's a bug with NASA_AUDIO_MP3: https://github.com/pytorch/torchcodec/issues/553 + expected_pts = ( + 0.072 + if asset is NASA_AUDIO_MP3 + else asset.get_frame_info(idx=0).pts_seconds + ) + assert samples.pts_seconds == expected_pts + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_at_frame_boundaries(self, asset): + decoder = AudioDecoder(asset.path) + + start_frame_index, stop_frame_index = 10, 40 + start_seconds = asset.get_frame_info(start_frame_index).pts_seconds + stop_seconds = asset.get_frame_info(stop_frame_index).pts_seconds + + samples = decoder.get_samples_played_in_range( + start_seconds=start_seconds, stop_seconds=stop_seconds + ) + + reference_frames = asset.get_frame_data_by_range( + start=start_frame_index, stop=stop_frame_index + ) + + assert samples.pts_seconds == start_seconds + num_samples = samples.data.shape[1] + assert ( + num_samples + == reference_frames.shape[1] + == (stop_seconds - start_seconds) * decoder.metadata.sample_rate + ) + torch.testing.assert_close(samples.data, reference_frames) + assert samples.sample_rate == asset.sample_rate + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_not_at_frame_boundaries(self, asset): + decoder = AudioDecoder(asset.path) + + start_frame_index, stop_frame_index = 10, 40 + start_frame_info = asset.get_frame_info(start_frame_index) + stop_frame_info = asset.get_frame_info(stop_frame_index) + start_seconds = start_frame_info.pts_seconds + ( + start_frame_info.duration_seconds / 2 + ) + stop_seconds = stop_frame_info.pts_seconds + ( + stop_frame_info.duration_seconds / 2 + ) + samples = decoder.get_samples_played_in_range( + start_seconds=start_seconds, stop_seconds=stop_seconds + ) + + reference_frames = asset.get_frame_data_by_range( + start=start_frame_index, stop=stop_frame_index + 1 + ) + + assert samples.pts_seconds == start_seconds + num_samples = samples.data.shape[1] + assert num_samples < reference_frames.shape[1] + assert ( + num_samples == (stop_seconds - start_seconds) * decoder.metadata.sample_rate + ) + assert samples.sample_rate == asset.sample_rate + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_start_equals_stop(self, asset): + decoder = AudioDecoder(asset.path) + samples = decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=3) + assert samples.data.shape == (0, 0) + + def test_frame_start_is_not_zero(self): + # For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.072 [1]. + # So if we request start = 0.05, we shouldn't be truncating anything. + # + # [1] well, really it's at 0.138125, not 0.072 (see + # https://github.com/pytorch/torchcodec/issues/553), but for the purpose + # of this test it doesn't matter. + + asset = NASA_AUDIO_MP3 + start_seconds = 0.05 # this is less than the first frame's pts + stop_frame_index = 10 + stop_seconds = asset.get_frame_info(stop_frame_index).pts_seconds + + decoder = AudioDecoder(asset.path) + + samples = decoder.get_samples_played_in_range( + start_seconds=start_seconds, stop_seconds=stop_seconds + ) + + reference_frames = asset.get_frame_data_by_range(start=0, stop=stop_frame_index) + torch.testing.assert_close(samples.data, reference_frames) diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 01d54ea9..724eff62 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -742,7 +742,7 @@ def test_decode_start_equal_stop(self, asset): frames, pts_seconds = get_frames_by_pts_in_range_audio( decoder, start_seconds=1, stop_seconds=1 ) - assert frames.shape == (0,) + assert frames.shape == (0, 0) assert pts_seconds == 0 @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) diff --git a/test/test_frame_dataclasses.py b/test/test_frame_dataclasses.py index a9840f22..e014cd4d 100644 --- a/test/test_frame_dataclasses.py +++ b/test/test_frame_dataclasses.py @@ -1,10 +1,11 @@ import pytest import torch -from torchcodec import Frame, FrameBatch +from torchcodec import AudioSamples, Frame, FrameBatch -def test_frame_unpacking(): +def test_unpacking(): data, pts_seconds, duration_seconds = Frame(torch.rand(3, 4, 5), 2, 3) # noqa + data, pts_seconds, sample_rate = AudioSamples(torch.rand(2, 4), 2, 16_000) def test_frame_error(): @@ -139,3 +140,18 @@ def test_framebatch_indexing(): fb_fancy = fb[[[0], [1]]] # select T=0 and N=1. assert isinstance(fb_fancy, FrameBatch) assert fb_fancy.data.shape == (1, C, H, W) + + +def test_audio_samples_error(): + with pytest.raises(ValueError, match="data must be 2-dimensional"): + AudioSamples( + data=torch.rand(1), + pts_seconds=1, + sample_rate=16_000, + ) + with pytest.raises(ValueError, match="data must be 2-dimensional"): + AudioSamples( + data=torch.rand(1, 2, 3), + pts_seconds=1, + sample_rate=16_000, + )