From 4495150f46ae62001d196aa3e931df0f9ca2f32b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 12 Mar 2025 20:07:22 +0000 Subject: [PATCH 1/6] Add get_samples_played_in_range public method --- src/torchcodec/decoders/_audio_decoder.py | 49 +++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index a99d5d3f..680f1289 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -37,3 +37,52 @@ def __init__( ) = get_and_validate_stream_metadata( decoder=self._decoder, stream_index=stream_index, media_type="audio" ) + + def get_samples_played_in_range( + self, start_seconds: float = 0, stop_seconds: Optional[float] = None + ) -> Tensor: + """TODO-AUDIO docs""" + if stop_seconds is not None and not start_seconds <= stop_seconds: + raise ValueError( + f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})." + ) + if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds: + raise ValueError( + f"Invalid start seconds: {start_seconds}. " + f"It must be greater than or equal to {self._begin_stream_seconds} " + f"and less than or equal to {self._end_stream_seconds}." + ) + frames, first_pts = core.get_frames_by_pts_in_range_audio( + self._decoder, + start_seconds=start_seconds, + stop_seconds=stop_seconds, + ) + first_pts = first_pts.item() + + # x = frame boundaries + # + # first_pts last_pts + # v v + # ....x..........x..........x...........x..........x..........x..........x..... + # ^ ^ + # start_seconds stop_seconds + # + # We want to return the samples in [start_seconds, stop_seconds). But + # because the core API is based on frames, the `frames` tensor contains + # the samples in [first_pts, last_pts).pts + # + # So we return a view on that tensor and do some basic math to figure + # out where to chunk it. + + offset_beginning = round( + (max(0, start_seconds - first_pts)) * self.metadata.sample_rate + ) + + num_samples = frames.shape[1] + offset_end = num_samples + last_pts = first_pts + num_samples / self.metadata.sample_rate + if stop_seconds is not None and stop_seconds < last_pts: + offset_end -= round((last_pts - stop_seconds) * self.metadata.sample_rate) + + return frames[:, offset_beginning:offset_end] + # return frames[:, offset_beginning:offset_end] From 277fac2fe233adac3f77f9fb556b6a65187e4557 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 12 Mar 2025 20:08:16 +0000 Subject: [PATCH 2/6] Nit --- src/torchcodec/decoders/_audio_decoder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index 680f1289..1edb3b37 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -69,10 +69,10 @@ def get_samples_played_in_range( # # We want to return the samples in [start_seconds, stop_seconds). But # because the core API is based on frames, the `frames` tensor contains - # the samples in [first_pts, last_pts).pts + # the samples in [first_pts, last_pts) # - # So we return a view on that tensor and do some basic math to figure - # out where to chunk it. + # So we do some basic math to figure out the position of the view that + # we'l; return. offset_beginning = round( (max(0, start_seconds - first_pts)) * self.metadata.sample_rate From 0f9e14dd500105f355487a74e2cab2ec330b7664 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 13 Mar 2025 10:13:22 +0000 Subject: [PATCH 3/6] WIP --- src/torchcodec/__init__.py | 2 +- src/torchcodec/_frame.py | 24 +++++++++++++- src/torchcodec/decoders/_audio_decoder.py | 34 +++++++++++++------- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- test/decoders/test_decoders.py | 32 ++++++++++++++++++ test/test_frame_dataclasses.py | 19 +++++++++-- 6 files changed, 97 insertions(+), 16 deletions(-) diff --git a/src/torchcodec/__init__.py b/src/torchcodec/__init__.py index e96aa5b3..2b6c9fb3 100644 --- a/src/torchcodec/__init__.py +++ b/src/torchcodec/__init__.py @@ -6,7 +6,7 @@ # Note: usort wants to put Frame and FrameBatch after decoders and samplers, # but that results in circular import. -from ._frame import Frame, FrameBatch # usort:skip # noqa +from ._frame import AudioSamples, Frame, FrameBatch # usort:skip # noqa from . import decoders, samplers # noqa try: diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index a8fc7f5b..958373ae 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -12,7 +12,7 @@ def _frame_repr(self): - # Utility to replace Frame and FrameBatch __repr__ method. This prints the + # Utility to replace __repr__ method of dataclasses below. This prints the # shape of the .data tensor rather than printing the (potentially very long) # data tensor itself. s = self.__class__.__name__ + ":\n" @@ -114,3 +114,25 @@ def __len__(self): def __repr__(self): return _frame_repr(self) + +@dataclass +class AudioSamples(Iterable): + """Audio samples with associated metadata.""" + # TODO-AUDIO: docs + data: Tensor + pts_seconds: float + sample_rate: int + def __post_init__(self): + # This is called after __init__() when a Frame is created. We can run + # input validation checks here. + if not self.data.ndim == 2: + raise ValueError(f"data must be 2-dimensional, got {self.data.shape = }") + self.pts_seconds = float(self.pts_seconds) + self.sample_rate = int(self.sample_rate) + + def __iter__(self) -> Iterator[Union[Tensor, float]]: + for field in dataclasses.fields(self): + yield getattr(self, field.name) + + def __repr__(self): + return _frame_repr(self) \ No newline at end of file diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index 1edb3b37..e510d6df 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -9,6 +9,7 @@ from torch import Tensor +from torchcodec import AudioSamples from torchcodec.decoders import _core as core from torchcodec.decoders._decoder_utils import ( create_decoder, @@ -39,7 +40,7 @@ def __init__( ) def get_samples_played_in_range( - self, start_seconds: float = 0, stop_seconds: Optional[float] = None + self, start_seconds: float, stop_seconds: Optional[float] = None ) -> Tensor: """TODO-AUDIO docs""" if stop_seconds is not None and not start_seconds <= stop_seconds: @@ -63,26 +64,37 @@ def get_samples_played_in_range( # # first_pts last_pts # v v - # ....x..........x..........x...........x..........x..........x..........x..... + # ....x..........x..........x...........x..........x..........x..... # ^ ^ # start_seconds stop_seconds # # We want to return the samples in [start_seconds, stop_seconds). But # because the core API is based on frames, the `frames` tensor contains # the samples in [first_pts, last_pts) - # # So we do some basic math to figure out the position of the view that - # we'l; return. + # we'll return. - offset_beginning = round( - (max(0, start_seconds - first_pts)) * self.metadata.sample_rate - ) + # TODO: sample_rate is either the original one from metadata, or the + # user-specified one (NIY) + sample_rate = self.metadata.sample_rate + + if first_pts < start_seconds: + offset_beginning = round((start_seconds - first_pts) * sample_rate) + output_pts_seconds = start_seconds + else: + offset_beginning = 0 + output_pts_seconds = first_pts num_samples = frames.shape[1] - offset_end = num_samples last_pts = first_pts + num_samples / self.metadata.sample_rate if stop_seconds is not None and stop_seconds < last_pts: - offset_end -= round((last_pts - stop_seconds) * self.metadata.sample_rate) + offset_end = num_samples - round((last_pts - stop_seconds) * sample_rate) + else: + offset_end = num_samples + + return AudioSamples( + data=frames[:, offset_beginning:offset_end], + pts_seconds=output_pts_seconds, + sample_rate=sample_rate, + ) - return frames[:, offset_beginning:offset_end] - # return frames[:, offset_beginning:offset_end] diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 60c10055..3ca6380b 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -147,7 +147,7 @@ class VideoDecoder { // DECODING AND SEEKING APIs // -------------------------------------------------------------------------- - // All public decoding entry points return either a FrameOutput or a + // All public video decoding entry points return either a FrameOutput or a // FrameBatchOutput. // They are the equivalent of the user-facing Frame and FrameBatch classes in // Python. They contain RGB decoded frames along with some associated data diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index e0339120..4158d827 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -955,3 +955,35 @@ def test_metadata(self, asset): ) assert decoder.metadata.sample_rate == asset.sample_rate assert decoder.metadata.num_channels == asset.num_channels + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_get_all_samples(self, asset): + decoder = AudioDecoder(asset.path) + + samples = decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=None) + + reference_frames = asset.get_frame_data_by_range( + start=0, + stop=asset.get_frame_index(pts_seconds=asset.duration_seconds) + 1 + ) + + torch.testing.assert_close(samples.data, reference_frames) + assert samples.pts_seconds == asset.get_frame_info(idx=0).pts_seconds + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_get_samples_played_in_range(self, asset): + decoder = AudioDecoder(asset.path) + + start_seconds, stop_seconds = 2, 4 + samples = decoder.get_samples_played_in_range(start_seconds=start_seconds, stop_seconds=stop_seconds) + + reference_frames = asset.get_frame_data_by_range( + start=asset.get_frame_index(pts_seconds=start_seconds), + stop=asset.get_frame_index(pts_seconds=stop_seconds) + 1 + ) + + assert samples.pts_seconds == start_seconds + num_samples = samples.data.shape[1] + assert num_samples < reference_frames.shape[1] + assert num_samples == (stop_seconds - start_seconds) * decoder.metadata.sample_rate + diff --git a/test/test_frame_dataclasses.py b/test/test_frame_dataclasses.py index a9840f22..2a501b00 100644 --- a/test/test_frame_dataclasses.py +++ b/test/test_frame_dataclasses.py @@ -1,10 +1,11 @@ import pytest import torch -from torchcodec import Frame, FrameBatch +from torchcodec import Frame, FrameBatch, AudioSamples -def test_frame_unpacking(): +def test_unpacking(): data, pts_seconds, duration_seconds = Frame(torch.rand(3, 4, 5), 2, 3) # noqa + data, pts_seconds, sample_rate = AudioSamples(torch.rand(2, 4), 2, 16_000) def test_frame_error(): @@ -139,3 +140,17 @@ def test_framebatch_indexing(): fb_fancy = fb[[[0], [1]]] # select T=0 and N=1. assert isinstance(fb_fancy, FrameBatch) assert fb_fancy.data.shape == (1, C, H, W) + +def test_audio_samples_error(): + with pytest.raises(ValueError, match="data must be 2-dimensional"): + AudioSamples( + data=torch.rand(1), + pts_seconds=1, + sample_rate=16_000, + ) + with pytest.raises(ValueError, match="data must be 2-dimensional"): + AudioSamples( + data=torch.rand(1, 2, 3), + pts_seconds=1, + sample_rate=16_000, + ) \ No newline at end of file From 00bb28d23f6615fe4e45282adbcacf87ffb13845 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 13 Mar 2025 10:51:07 +0000 Subject: [PATCH 4/6] Add pts test for audio --- test/decoders/test_ops.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index b510c8ef..01d54ea9 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -824,6 +824,32 @@ def get_reference_frames(start_seconds, stop_seconds): frames, get_reference_frames(start_seconds, stop_seconds) ) + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_pts(self, asset): + decoder = create_from_file(str(asset.path), seek_mode="approximate") + add_audio_stream(decoder) + + for frame_index in range(asset.num_frames): + frame_info = asset.get_frame_info(idx=frame_index) + start_seconds = frame_info.pts_seconds + + frames, pts_seconds = get_frames_by_pts_in_range_audio( + decoder, start_seconds=start_seconds, stop_seconds=start_seconds + 1e-3 + ) + torch.testing.assert_close( + frames, asset.get_frame_data_by_index(frame_index) + ) + + if asset is NASA_AUDIO_MP3 and frame_index == 0: + # TODO This is a bug. The 0.138125 is correct while 0.072 is + # incorrect, even though it comes from the decoded AVFrame's pts + # field. + # See https://github.com/pytorch/torchcodec/issues/553 + assert pts_seconds == 0.072 + assert start_seconds == 0.138125 + else: + assert pts_seconds == start_seconds + if __name__ == "__main__": pytest.main() From a7b67d54addea9f686ebfe3813d10d6cfcf78431 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 13 Mar 2025 11:55:21 +0000 Subject: [PATCH 5/6] WIP --- src/torchcodec/_frame.py | 5 +- src/torchcodec/decoders/_audio_decoder.py | 4 +- .../decoders/_core/VideoDecoder.cpp | 2 +- test/decoders/test_decoders.py | 117 ++++++++++++++++-- test/decoders/test_ops.py | 2 +- test/test_frame_dataclasses.py | 5 +- 6 files changed, 116 insertions(+), 19 deletions(-) diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index 958373ae..31a9b666 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -115,13 +115,16 @@ def __len__(self): def __repr__(self): return _frame_repr(self) + @dataclass class AudioSamples(Iterable): """Audio samples with associated metadata.""" + # TODO-AUDIO: docs data: Tensor pts_seconds: float sample_rate: int + def __post_init__(self): # This is called after __init__() when a Frame is created. We can run # input validation checks here. @@ -135,4 +138,4 @@ def __iter__(self) -> Iterator[Union[Tensor, float]]: yield getattr(self, field.name) def __repr__(self): - return _frame_repr(self) \ No newline at end of file + return _frame_repr(self) diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index e510d6df..a63ea7da 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -82,6 +82,9 @@ def get_samples_played_in_range( offset_beginning = round((start_seconds - first_pts) * sample_rate) output_pts_seconds = start_seconds else: + # In normal cases we'll have first_pts <= start_pts, but in some + # edge cases it's possible to have first_pts > start_seconds, + # typically if the stream's first frame's pts isn't exactly 0. offset_beginning = 0 output_pts_seconds = first_pts @@ -97,4 +100,3 @@ def get_samples_played_in_range( pts_seconds=output_pts_seconds, sample_rate=sample_rate, ) - diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 0e287a5b..7c75b8fb 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -854,7 +854,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio( if (startSeconds == stopSeconds) { // For consistency with video - return AudioFramesOutput{torch::empty({0}), 0.0}; + return AudioFramesOutput{torch::empty({0, 0}), 0.0}; } StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index 4158d827..6cd01dad 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -957,33 +957,124 @@ def test_metadata(self, asset): assert decoder.metadata.num_channels == asset.num_channels @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) - def test_get_all_samples(self, asset): + def test_error(self, asset): decoder = AudioDecoder(asset.path) - - samples = decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=None) + + with pytest.raises(ValueError, match="Invalid start seconds"): + decoder.get_samples_played_in_range(start_seconds=-1300) + + with pytest.raises(ValueError, match="Invalid start seconds"): + decoder.get_samples_played_in_range(start_seconds=9999) + + with pytest.raises(ValueError, match="Invalid start seconds"): + decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=2) + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + @pytest.mark.parametrize("stop_seconds", (None, "duration", 99999999)) + def test_get_all_samples(self, asset, stop_seconds): + decoder = AudioDecoder(asset.path) + + if stop_seconds == "duration": + stop_seconds = asset.duration_seconds + + samples = decoder.get_samples_played_in_range( + start_seconds=0, stop_seconds=stop_seconds + ) reference_frames = asset.get_frame_data_by_range( - start=0, - stop=asset.get_frame_index(pts_seconds=asset.duration_seconds) + 1 + start=0, stop=asset.get_frame_index(pts_seconds=asset.duration_seconds) + 1 ) torch.testing.assert_close(samples.data, reference_frames) - assert samples.pts_seconds == asset.get_frame_info(idx=0).pts_seconds + assert samples.sample_rate == asset.sample_rate + + # TODO there's a bug with NASA_AUDIO_MP3: https://github.com/pytorch/torchcodec/issues/553 + expected_pts = ( + 0.072 + if asset is NASA_AUDIO_MP3 + else asset.get_frame_info(idx=0).pts_seconds + ) + assert samples.pts_seconds == expected_pts @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) - def test_get_samples_played_in_range(self, asset): + def test_at_frame_boundaries(self, asset): decoder = AudioDecoder(asset.path) - - start_seconds, stop_seconds = 2, 4 - samples = decoder.get_samples_played_in_range(start_seconds=start_seconds, stop_seconds=stop_seconds) + + start_frame_index, stop_frame_index = 10, 40 + start_seconds = asset.get_frame_info(start_frame_index).pts_seconds + stop_seconds = asset.get_frame_info(stop_frame_index).pts_seconds + + samples = decoder.get_samples_played_in_range( + start_seconds=start_seconds, stop_seconds=stop_seconds + ) reference_frames = asset.get_frame_data_by_range( - start=asset.get_frame_index(pts_seconds=start_seconds), - stop=asset.get_frame_index(pts_seconds=stop_seconds) + 1 + start=start_frame_index, stop=stop_frame_index + ) + + assert samples.pts_seconds == start_seconds + num_samples = samples.data.shape[1] + assert ( + num_samples + == reference_frames.shape[1] + == (stop_seconds - start_seconds) * decoder.metadata.sample_rate + ) + torch.testing.assert_close(samples.data, reference_frames) + assert samples.sample_rate == asset.sample_rate + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_not_at_frame_boundaries(self, asset): + decoder = AudioDecoder(asset.path) + + start_frame_index, stop_frame_index = 10, 40 + start_frame_info = asset.get_frame_info(start_frame_index) + stop_frame_info = asset.get_frame_info(stop_frame_index) + start_seconds = start_frame_info.pts_seconds + ( + start_frame_info.duration_seconds / 2 + ) + stop_seconds = stop_frame_info.pts_seconds + ( + stop_frame_info.duration_seconds / 2 + ) + samples = decoder.get_samples_played_in_range( + start_seconds=start_seconds, stop_seconds=stop_seconds + ) + + reference_frames = asset.get_frame_data_by_range( + start=start_frame_index, stop=stop_frame_index + 1 ) assert samples.pts_seconds == start_seconds num_samples = samples.data.shape[1] assert num_samples < reference_frames.shape[1] - assert num_samples == (stop_seconds - start_seconds) * decoder.metadata.sample_rate + assert ( + num_samples == (stop_seconds - start_seconds) * decoder.metadata.sample_rate + ) + assert samples.sample_rate == asset.sample_rate + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_start_equals_stop(self, asset): + decoder = AudioDecoder(asset.path) + samples = decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=3) + assert samples.data.shape == (0, 0) + + def test_frame_start_is_not_zero(self): + # For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.072 [1]. + # So if we request start = 0.05, we shouldn't be truncating anything. + # + # [1] well, really it's at 0.138125, not 0.072 (see + # https://github.com/pytorch/torchcodec/issues/553), but for the purpose + # of this test it doesn't matter. + + asset = NASA_AUDIO_MP3 + start_seconds = 0.05 # this is less than the first frame's pts + stop_frame_index = 10 + stop_seconds = asset.get_frame_info(stop_frame_index).pts_seconds + decoder = AudioDecoder(asset.path) + + samples = decoder.get_samples_played_in_range( + start_seconds=start_seconds, stop_seconds=stop_seconds + ) + + reference_frames = asset.get_frame_data_by_range(start=0, stop=stop_frame_index) + torch.testing.assert_close(samples.data, reference_frames) diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 01d54ea9..724eff62 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -742,7 +742,7 @@ def test_decode_start_equal_stop(self, asset): frames, pts_seconds = get_frames_by_pts_in_range_audio( decoder, start_seconds=1, stop_seconds=1 ) - assert frames.shape == (0,) + assert frames.shape == (0, 0) assert pts_seconds == 0 @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) diff --git a/test/test_frame_dataclasses.py b/test/test_frame_dataclasses.py index 2a501b00..e014cd4d 100644 --- a/test/test_frame_dataclasses.py +++ b/test/test_frame_dataclasses.py @@ -1,6 +1,6 @@ import pytest import torch -from torchcodec import Frame, FrameBatch, AudioSamples +from torchcodec import AudioSamples, Frame, FrameBatch def test_unpacking(): @@ -141,6 +141,7 @@ def test_framebatch_indexing(): assert isinstance(fb_fancy, FrameBatch) assert fb_fancy.data.shape == (1, C, H, W) + def test_audio_samples_error(): with pytest.raises(ValueError, match="data must be 2-dimensional"): AudioSamples( @@ -153,4 +154,4 @@ def test_audio_samples_error(): data=torch.rand(1, 2, 3), pts_seconds=1, sample_rate=16_000, - ) \ No newline at end of file + ) From 02067cf28fa916ce0ec38905135f800796440910 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 13 Mar 2025 12:40:40 +0000 Subject: [PATCH 6/6] Fix Fing mypy --- src/torchcodec/decoders/_audio_decoder.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index a63ea7da..34292751 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -38,10 +38,11 @@ def __init__( ) = get_and_validate_stream_metadata( decoder=self._decoder, stream_index=stream_index, media_type="audio" ) + assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy def get_samples_played_in_range( self, start_seconds: float, stop_seconds: Optional[float] = None - ) -> Tensor: + ) -> AudioSamples: """TODO-AUDIO docs""" if stop_seconds is not None and not start_seconds <= stop_seconds: raise ValueError( @@ -76,8 +77,12 @@ def get_samples_played_in_range( # TODO: sample_rate is either the original one from metadata, or the # user-specified one (NIY) + assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy sample_rate = self.metadata.sample_rate + # TODO: metadata's sample_rate should probably not be Optional + assert sample_rate is not None # mypy. + if first_pts < start_seconds: offset_beginning = round((start_seconds - first_pts) * sample_rate) output_pts_seconds = start_seconds