Skip to content

Commit 8e611bb

Browse files
authored
Add sample_format to audio metadata (#557)
1 parent f5868be commit 8e611bb

File tree

9 files changed

+308
-10
lines changed

9 files changed

+308
-10
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ void VideoDecoder::initializeDecoder() {
170170
}
171171
containerMetadata_.numVideoStreams++;
172172
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
173+
AVSampleFormat format =
174+
static_cast<AVSampleFormat>(avStream->codecpar->format);
175+
streamMetadata.sampleFormat = av_get_sample_fmt_name(format);
173176
containerMetadata_.numAudioStreams++;
174177
}
175178

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class VideoDecoder {
8181
// Audio-only fields
8282
std::optional<int64_t> sampleRate;
8383
std::optional<int64_t> numChannels;
84+
std::optional<std::string> sampleFormat;
8485
};
8586

8687
struct ContainerMetadata {

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -495,12 +495,15 @@ std::string get_stream_json_metadata(
495495
if (streamMetadata.numChannels.has_value()) {
496496
map["numChannels"] = std::to_string(*streamMetadata.numChannels);
497497
}
498+
if (streamMetadata.sampleFormat.has_value()) {
499+
map["sampleFormat"] = quoteValue(streamMetadata.sampleFormat.value());
500+
}
498501
if (streamMetadata.mediaType == AVMEDIA_TYPE_VIDEO) {
499-
map["mediaType"] = "\"video\"";
502+
map["mediaType"] = quoteValue("video");
500503
} else if (streamMetadata.mediaType == AVMEDIA_TYPE_AUDIO) {
501-
map["mediaType"] = "\"audio\"";
504+
map["mediaType"] = quoteValue("audio");
502505
} else {
503-
map["mediaType"] = "\"other\"";
506+
map["mediaType"] = quoteValue("other");
504507
}
505508
return mapToJson(map);
506509
}

src/torchcodec/decoders/_core/_metadata.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,9 @@ def __repr__(self):
161161
class AudioStreamMetadata(StreamMetadata):
162162
"""Metadata of a single audio stream."""
163163

164-
# TODO-AUDIO Add sample format field
165164
sample_rate: Optional[int]
166165
num_channels: Optional[int]
166+
sample_format: Optional[str]
167167

168168
def __repr__(self):
169169
return super().__repr__()
@@ -240,6 +240,7 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
240240
AudioStreamMetadata(
241241
sample_rate=stream_dict.get("sampleRate"),
242242
num_channels=stream_dict.get("numChannels"),
243+
sample_format=stream_dict.get("sampleFormat"),
243244
**common_meta,
244245
)
245246
)

test/decoders/test_decoders.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
NASA_AUDIO,
2626
NASA_AUDIO_MP3,
2727
NASA_VIDEO,
28+
SINE_MONO_S32,
2829
)
2930

3031

@@ -940,7 +941,7 @@ def get_some_frames(decoder):
940941

941942

942943
class TestAudioDecoder:
943-
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
944+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))
944945
def test_metadata(self, asset):
945946
decoder = AudioDecoder(asset.path)
946947
assert isinstance(decoder.metadata, AudioStreamMetadata)
@@ -955,6 +956,7 @@ def test_metadata(self, asset):
955956
)
956957
assert decoder.metadata.sample_rate == asset.sample_rate
957958
assert decoder.metadata.num_channels == asset.num_channels
959+
assert decoder.metadata.sample_format == asset.sample_format
958960

959961
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
960962
def test_error(self, asset):

test/decoders/test_metadata.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def test_get_metadata(metadata_getter):
9090
)
9191
assert best_audio_stream_metadata.bit_rate == 128837
9292
assert best_audio_stream_metadata.codec == "aac"
93+
assert best_audio_stream_metadata.sample_format == "fltp"
9394

9495

9596
@pytest.mark.parametrize(
@@ -109,6 +110,7 @@ def test_get_metadata_audio_file(metadata_getter):
109110
)
110111
assert best_audio_stream_metadata.bit_rate == 64000
111112
assert best_audio_stream_metadata.codec == "mp3"
113+
assert best_audio_stream_metadata.sample_format == "fltp"
112114

113115

114116
@pytest.mark.parametrize(

test/resources/sine_mono_s32.wav

250 KB
Binary file not shown.
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
[
2+
{
3+
"duration_time": "0.064000",
4+
"pts_time": "0.000000"
5+
},
6+
{
7+
"duration_time": "0.064000",
8+
"pts_time": "0.064000"
9+
},
10+
{
11+
"duration_time": "0.064000",
12+
"pts_time": "0.128000"
13+
},
14+
{
15+
"duration_time": "0.064000",
16+
"pts_time": "0.192000"
17+
},
18+
{
19+
"duration_time": "0.064000",
20+
"pts_time": "0.256000"
21+
},
22+
{
23+
"duration_time": "0.064000",
24+
"pts_time": "0.320000"
25+
},
26+
{
27+
"duration_time": "0.064000",
28+
"pts_time": "0.384000"
29+
},
30+
{
31+
"duration_time": "0.064000",
32+
"pts_time": "0.448000"
33+
},
34+
{
35+
"duration_time": "0.064000",
36+
"pts_time": "0.512000"
37+
},
38+
{
39+
"duration_time": "0.064000",
40+
"pts_time": "0.576000"
41+
},
42+
{
43+
"duration_time": "0.064000",
44+
"pts_time": "0.640000"
45+
},
46+
{
47+
"duration_time": "0.064000",
48+
"pts_time": "0.704000"
49+
},
50+
{
51+
"duration_time": "0.064000",
52+
"pts_time": "0.768000"
53+
},
54+
{
55+
"duration_time": "0.064000",
56+
"pts_time": "0.832000"
57+
},
58+
{
59+
"duration_time": "0.064000",
60+
"pts_time": "0.896000"
61+
},
62+
{
63+
"duration_time": "0.064000",
64+
"pts_time": "0.960000"
65+
},
66+
{
67+
"duration_time": "0.064000",
68+
"pts_time": "1.024000"
69+
},
70+
{
71+
"duration_time": "0.064000",
72+
"pts_time": "1.088000"
73+
},
74+
{
75+
"duration_time": "0.064000",
76+
"pts_time": "1.152000"
77+
},
78+
{
79+
"duration_time": "0.064000",
80+
"pts_time": "1.216000"
81+
},
82+
{
83+
"duration_time": "0.064000",
84+
"pts_time": "1.280000"
85+
},
86+
{
87+
"duration_time": "0.064000",
88+
"pts_time": "1.344000"
89+
},
90+
{
91+
"duration_time": "0.064000",
92+
"pts_time": "1.408000"
93+
},
94+
{
95+
"duration_time": "0.064000",
96+
"pts_time": "1.472000"
97+
},
98+
{
99+
"duration_time": "0.064000",
100+
"pts_time": "1.536000"
101+
},
102+
{
103+
"duration_time": "0.064000",
104+
"pts_time": "1.600000"
105+
},
106+
{
107+
"duration_time": "0.064000",
108+
"pts_time": "1.664000"
109+
},
110+
{
111+
"duration_time": "0.064000",
112+
"pts_time": "1.728000"
113+
},
114+
{
115+
"duration_time": "0.064000",
116+
"pts_time": "1.792000"
117+
},
118+
{
119+
"duration_time": "0.064000",
120+
"pts_time": "1.856000"
121+
},
122+
{
123+
"duration_time": "0.064000",
124+
"pts_time": "1.920000"
125+
},
126+
{
127+
"duration_time": "0.064000",
128+
"pts_time": "1.984000"
129+
},
130+
{
131+
"duration_time": "0.064000",
132+
"pts_time": "2.048000"
133+
},
134+
{
135+
"duration_time": "0.064000",
136+
"pts_time": "2.112000"
137+
},
138+
{
139+
"duration_time": "0.064000",
140+
"pts_time": "2.176000"
141+
},
142+
{
143+
"duration_time": "0.064000",
144+
"pts_time": "2.240000"
145+
},
146+
{
147+
"duration_time": "0.064000",
148+
"pts_time": "2.304000"
149+
},
150+
{
151+
"duration_time": "0.064000",
152+
"pts_time": "2.368000"
153+
},
154+
{
155+
"duration_time": "0.064000",
156+
"pts_time": "2.432000"
157+
},
158+
{
159+
"duration_time": "0.064000",
160+
"pts_time": "2.496000"
161+
},
162+
{
163+
"duration_time": "0.064000",
164+
"pts_time": "2.560000"
165+
},
166+
{
167+
"duration_time": "0.064000",
168+
"pts_time": "2.624000"
169+
},
170+
{
171+
"duration_time": "0.064000",
172+
"pts_time": "2.688000"
173+
},
174+
{
175+
"duration_time": "0.064000",
176+
"pts_time": "2.752000"
177+
},
178+
{
179+
"duration_time": "0.064000",
180+
"pts_time": "2.816000"
181+
},
182+
{
183+
"duration_time": "0.064000",
184+
"pts_time": "2.880000"
185+
},
186+
{
187+
"duration_time": "0.064000",
188+
"pts_time": "2.944000"
189+
},
190+
{
191+
"duration_time": "0.064000",
192+
"pts_time": "3.008000"
193+
},
194+
{
195+
"duration_time": "0.064000",
196+
"pts_time": "3.072000"
197+
},
198+
{
199+
"duration_time": "0.064000",
200+
"pts_time": "3.136000"
201+
},
202+
{
203+
"duration_time": "0.064000",
204+
"pts_time": "3.200000"
205+
},
206+
{
207+
"duration_time": "0.064000",
208+
"pts_time": "3.264000"
209+
},
210+
{
211+
"duration_time": "0.064000",
212+
"pts_time": "3.328000"
213+
},
214+
{
215+
"duration_time": "0.064000",
216+
"pts_time": "3.392000"
217+
},
218+
{
219+
"duration_time": "0.064000",
220+
"pts_time": "3.456000"
221+
},
222+
{
223+
"duration_time": "0.064000",
224+
"pts_time": "3.520000"
225+
},
226+
{
227+
"duration_time": "0.064000",
228+
"pts_time": "3.584000"
229+
},
230+
{
231+
"duration_time": "0.064000",
232+
"pts_time": "3.648000"
233+
},
234+
{
235+
"duration_time": "0.064000",
236+
"pts_time": "3.712000"
237+
},
238+
{
239+
"duration_time": "0.064000",
240+
"pts_time": "3.776000"
241+
},
242+
{
243+
"duration_time": "0.064000",
244+
"pts_time": "3.840000"
245+
},
246+
{
247+
"duration_time": "0.064000",
248+
"pts_time": "3.904000"
249+
},
250+
{
251+
"duration_time": "0.032000",
252+
"pts_time": "3.968000"
253+
}
254+
]

0 commit comments

Comments
 (0)