diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 7c75b8fb..79bad294 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -170,6 +170,9 @@ void VideoDecoder::initializeDecoder() { } containerMetadata_.numVideoStreams++; } else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) { + AVSampleFormat format = + static_cast(avStream->codecpar->format); + streamMetadata.sampleFormat = av_get_sample_fmt_name(format); containerMetadata_.numAudioStreams++; } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 3ca6380b..51a780fb 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -81,6 +81,7 @@ class VideoDecoder { // Audio-only fields std::optional sampleRate; std::optional numChannels; + std::optional sampleFormat; }; struct ContainerMetadata { diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 7472c3de..fffb1118 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -495,12 +495,15 @@ std::string get_stream_json_metadata( if (streamMetadata.numChannels.has_value()) { map["numChannels"] = std::to_string(*streamMetadata.numChannels); } + if (streamMetadata.sampleFormat.has_value()) { + map["sampleFormat"] = quoteValue(streamMetadata.sampleFormat.value()); + } if (streamMetadata.mediaType == AVMEDIA_TYPE_VIDEO) { - map["mediaType"] = "\"video\""; + map["mediaType"] = quoteValue("video"); } else if (streamMetadata.mediaType == AVMEDIA_TYPE_AUDIO) { - map["mediaType"] = "\"audio\""; + map["mediaType"] = quoteValue("audio"); } else { - map["mediaType"] = "\"other\""; + map["mediaType"] = quoteValue("other"); } return mapToJson(map); } diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py index fcfddecc..bf2e0256 100644 --- a/src/torchcodec/decoders/_core/_metadata.py +++ b/src/torchcodec/decoders/_core/_metadata.py @@ -161,9 +161,9 @@ def __repr__(self): class AudioStreamMetadata(StreamMetadata): """Metadata of a single audio stream.""" - # TODO-AUDIO Add sample format field sample_rate: Optional[int] num_channels: Optional[int] + sample_format: Optional[str] def __repr__(self): return super().__repr__() @@ -240,6 +240,7 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata: AudioStreamMetadata( sample_rate=stream_dict.get("sampleRate"), num_channels=stream_dict.get("numChannels"), + sample_format=stream_dict.get("sampleFormat"), **common_meta, ) ) diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index 6cd01dad..1ddf3252 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -25,6 +25,7 @@ NASA_AUDIO, NASA_AUDIO_MP3, NASA_VIDEO, + SINE_MONO_S32, ) @@ -940,7 +941,7 @@ def get_some_frames(decoder): class TestAudioDecoder: - @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32)) def test_metadata(self, asset): decoder = AudioDecoder(asset.path) assert isinstance(decoder.metadata, AudioStreamMetadata) @@ -955,6 +956,7 @@ def test_metadata(self, asset): ) assert decoder.metadata.sample_rate == asset.sample_rate assert decoder.metadata.num_channels == asset.num_channels + assert decoder.metadata.sample_format == asset.sample_format @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) def test_error(self, asset): diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py index 55046bf5..dec98603 100644 --- a/test/decoders/test_metadata.py +++ b/test/decoders/test_metadata.py @@ -90,6 +90,7 @@ def test_get_metadata(metadata_getter): ) assert best_audio_stream_metadata.bit_rate == 128837 assert best_audio_stream_metadata.codec == "aac" + assert best_audio_stream_metadata.sample_format == "fltp" @pytest.mark.parametrize( @@ -109,6 +110,7 @@ def test_get_metadata_audio_file(metadata_getter): ) assert best_audio_stream_metadata.bit_rate == 64000 assert best_audio_stream_metadata.codec == "mp3" + assert best_audio_stream_metadata.sample_format == "fltp" @pytest.mark.parametrize( diff --git a/test/resources/sine_mono_s32.wav b/test/resources/sine_mono_s32.wav new file mode 100644 index 00000000..93182c4e Binary files /dev/null and b/test/resources/sine_mono_s32.wav differ diff --git a/test/resources/sine_mono_s32.wav.stream0.all_frames_info.json b/test/resources/sine_mono_s32.wav.stream0.all_frames_info.json new file mode 100644 index 00000000..224d8fac --- /dev/null +++ b/test/resources/sine_mono_s32.wav.stream0.all_frames_info.json @@ -0,0 +1,254 @@ +[ + { + "duration_time": "0.064000", + "pts_time": "0.000000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.064000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.128000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.192000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.256000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.320000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.384000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.448000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.512000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.576000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.640000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.704000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.768000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.832000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.896000" + }, + { + "duration_time": "0.064000", + "pts_time": "0.960000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.024000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.088000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.152000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.216000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.280000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.344000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.408000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.472000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.536000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.600000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.664000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.728000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.792000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.856000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.920000" + }, + { + "duration_time": "0.064000", + "pts_time": "1.984000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.048000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.112000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.176000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.240000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.304000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.368000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.432000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.496000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.560000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.624000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.688000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.752000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.816000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.880000" + }, + { + "duration_time": "0.064000", + "pts_time": "2.944000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.008000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.072000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.136000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.200000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.264000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.328000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.392000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.456000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.520000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.584000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.648000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.712000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.776000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.840000" + }, + { + "duration_time": "0.064000", + "pts_time": "3.904000" + }, + { + "duration_time": "0.032000", + "pts_time": "3.968000" + } +] diff --git a/test/utils.py b/test/utils.py index c6ce0ec8..c34ef51f 100644 --- a/test/utils.py +++ b/test/utils.py @@ -109,6 +109,7 @@ class TestAudioStreamInfo: num_channels: int duration_seconds: float num_frames: int + sample_format: str @dataclass @@ -340,9 +341,13 @@ def __post_init__(self): f"{self.filename}.stream{stream_index}.all_frames.pt" ) - self._reference_frames[stream_index] = torch.load( - frames_data_path, weights_only=True - ) + if frames_data_path.exists(): + # To ease development, we allow for the reference frames not to + # exist. It means the asset cannot be used to check validity of + # decoded frames. + self._reference_frames[stream_index] = torch.load( + frames_data_path, weights_only=True + ) def get_frame_data_by_index( self, idx: int, *, stream_index: Optional[int] = None @@ -404,6 +409,10 @@ def duration_seconds(self) -> float: def num_frames(self) -> int: return self.stream_infos[self.default_stream_index].num_frames + @property + def sample_format(self) -> str: + return self.stream_infos[self.default_stream_index].sample_format + NASA_AUDIO_MP3 = TestAudio( filename="nasa_13013.mp4.audio.mp3", @@ -411,7 +420,11 @@ def num_frames(self) -> int: frames={}, # Automatically loaded from json file stream_infos={ 0: TestAudioStreamInfo( - sample_rate=8_000, num_channels=2, duration_seconds=13.248, num_frames=183 + sample_rate=8_000, + num_channels=2, + duration_seconds=13.248, + num_frames=183, + sample_format="fltp", ) }, ) @@ -422,7 +435,26 @@ def num_frames(self) -> int: frames={}, # Automatically loaded from json file stream_infos={ 4: TestAudioStreamInfo( - sample_rate=16_000, num_channels=2, duration_seconds=13.056, num_frames=204 + sample_rate=16_000, + num_channels=2, + duration_seconds=13.056, + num_frames=204, + sample_format="fltp", + ) + }, +) + +SINE_MONO_S32 = TestAudio( + filename="sine_mono_s32.wav", + default_stream_index=0, + frames={}, # Automatically loaded from json file + stream_infos={ + 0: TestAudioStreamInfo( + sample_rate=16_000, + num_channels=1, + duration_seconds=4, + num_frames=63, + sample_format="s32", ) }, )