@@ -587,9 +587,9 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
 VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   validateActiveStream();
-  AVFrameStream avFrameStream = decodeAVFrame(
-      [this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
-  return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
+  UniqueAVFrame avFrame = decodeAVFrame(
+      [this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
+  return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
 }
 
 VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
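For context: this change replaces the AVFrameStream wrapper (an AVFrame plus a stream index, per the fields removed below) with UniqueAVFrame throughout the decoding path, and frame filters now take const UniqueAVFrame& instead of a raw AVFrame*. UniqueAVFrame is the project's RAII handle for AVFrame*; a minimal sketch of the idiom, assuming the usual FFmpeg deleter wiring (the actual alias is defined in the project's FFmpeg helpers, not in this diff):

    #include <memory>
    extern "C" {
    #include <libavutil/frame.h>
    }

    // av_frame_free() takes AVFrame**, so a small functor adapts it to the
    // deleter interface expected by std::unique_ptr.
    struct AVFrameDeleter {
      void operator()(AVFrame* frame) const {
        av_frame_free(&frame);  // frees the frame and unrefs its buffers
      }
    };
    using UniqueAVFrame = std::unique_ptr<AVFrame, AVFrameDeleter>;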
@@ -719,8 +719,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
   }
 
   setCursorPtsInSeconds(seconds);
-  AVFrameStream avFrameStream =
-      decodeAVFrame([seconds, this](AVFrame* avFrame) {
+  UniqueAVFrame avFrame =
+      decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) {
         StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
         double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
         double frameEndTime = ptsToSeconds(
@@ -739,7 +739,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
       });
 
   // Convert the frame to tensor.
-  FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrameStream);
+  FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrame);
   frameOutput.data = maybePermuteHWC2CHW(frameOutput.data);
   return frameOutput;
 }
@@ -895,14 +895,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
   auto finished = false;
   while (!finished) {
     try {
-      AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) {
-        return startPts < avFrame->pts + getDuration(avFrame);
-      });
-      // TODO: it's not great that we are getting a FrameOutput, which is
-      // intended for videos. We should consider bypassing
-      // convertAVFrameToFrameOutput and directly call
-      // convertAudioAVFrameToFrameOutputOnCPU.
-      auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
+      UniqueAVFrame avFrame =
+          decodeAVFrame([startPts](const UniqueAVFrame& avFrame) {
+            return startPts < avFrame->pts + getDuration(avFrame);
+          });
+      auto frameOutput = convertAVFrameToFrameOutput(avFrame);
       firstFramePtsSeconds =
           std::min(firstFramePtsSeconds, frameOutput.ptsSeconds);
       frames.push_back(frameOutput.data);
@@ -1039,8 +1036,8 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
 // LOW-LEVEL DECODING
 // --------------------------------------------------------------------------
 
-VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
-    std::function<bool(AVFrame*)> filterFunction) {
+UniqueAVFrame VideoDecoder::decodeAVFrame(
+    std::function<bool(const UniqueAVFrame&)> filterFunction) {
   validateActiveStream();
 
   resetDecodeStats();
@@ -1068,7 +1065,7 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
 
     decodeStats_.numFramesReceivedByDecoder++;
     // Is this the kind of frame we're looking for?
-    if (status == AVSUCCESS && filterFunction(avFrame.get())) {
+    if (status == AVSUCCESS && filterFunction(avFrame)) {
       // Yes, this is the frame we'll return; break out of the decoding loop.
      break;
    } else if (status == AVSUCCESS) {
@@ -1154,37 +1151,35 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
   streamInfo.lastDecodedAvFramePts = avFrame->pts;
   streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame);
 
-  return AVFrameStream(std::move(avFrame), activeStreamIndex_);
+  return avFrame;
 }
 
 // --------------------------------------------------------------------------
 // AVFRAME <-> FRAME OUTPUT CONVERSION
 // --------------------------------------------------------------------------
 
 VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
-    VideoDecoder::AVFrameStream& avFrameStream,
+    UniqueAVFrame& avFrame,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   // Convert the frame to tensor.
   FrameOutput frameOutput;
-  int streamIndex = avFrameStream.streamIndex;
-  AVFrame* avFrame = avFrameStream.avFrame.get();
-  frameOutput.streamIndex = streamIndex;
-  auto& streamInfo = streamInfos_[streamIndex];
+  auto& streamInfo = streamInfos_[activeStreamIndex_];
   frameOutput.ptsSeconds = ptsToSeconds(
-      avFrame->pts, formatContext_->streams[streamIndex]->time_base);
+      avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base);
   frameOutput.durationSeconds = ptsToSeconds(
-      getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
+      getDuration(avFrame),
+      formatContext_->streams[activeStreamIndex_]->time_base);
   if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
     convertAudioAVFrameToFrameOutputOnCPU(
-        avFrameStream, frameOutput, preAllocatedOutputTensor);
+        avFrame, frameOutput, preAllocatedOutputTensor);
   } else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
     convertAVFrameToFrameOutputOnCPU(
-        avFrameStream, frameOutput, preAllocatedOutputTensor);
+        avFrame, frameOutput, preAllocatedOutputTensor);
   } else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) {
     convertAVFrameToFrameOutputOnCuda(
         streamInfo.videoStreamOptions.device,
         streamInfo.videoStreamOptions,
-        avFrameStream,
+        avFrame,
         frameOutput,
         preAllocatedOutputTensor);
   } else {
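With AVFrameStream gone, decodeAVFrame() hands ownership straight to its caller: returning the local UniqueAVFrame by value moves it rather than copying (unique_ptr is move-only), and the stream index the wrapper used to carry is read from activeStreamIndex_ instead. A standalone illustration of that move-on-return behavior, not project code:

    #include <memory>

    std::unique_ptr<int> makeValue() {
      auto p = std::make_unique<int>(42);
      return p;  // a returned local unique_ptr is implicitly moved, never copied
    }

    int main() {
      std::unique_ptr<int> owned = makeValue();  // ownership transfers here
      return (*owned == 42) ? 0 : 1;
    }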
@@ -1205,14 +1200,13 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
 // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
-    VideoDecoder::AVFrameStream& avFrameStream,
+    UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  AVFrame* avFrame = avFrameStream.avFrame.get();
   auto& streamInfo = streamInfos_[activeStreamIndex_];
 
   auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
-      streamInfo.videoStreamOptions, *avFrame);
+      streamInfo.videoStreamOptions, avFrame);
   int expectedOutputHeight = frameDims.height;
   int expectedOutputWidth = frameDims.width;
 
@@ -1306,7 +1300,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
 }
 
 int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
-    const AVFrame* avFrame,
+    const UniqueAVFrame& avFrame,
     torch::Tensor& outputTensor) {
   StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
   SwsContext* swsContext = activeStreamInfo.swsContext.get();
@@ -1326,11 +1320,11 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
 }
 
 torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
-    const AVFrame* avFrame) {
+    const UniqueAVFrame& avFrame) {
   FilterGraphContext& filterGraphContext =
       streamInfos_[activeStreamIndex_].filterGraphContext;
   int status =
-      av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame);
+      av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame.get());
   if (status < AVSUCCESS) {
     throw std::runtime_error("Failed to add frame to buffer source context");
   }
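At FFmpeg C-API boundaries the raw pointer is borrowed via .get() while the UniqueAVFrame retains ownership; av_buffersrc_write_frame() takes its own reference to the frame and does not take ownership of the caller's. A hypothetical helper sketching the same borrow pattern (writeFrameToSource is illustrative, not part of this codebase; the alias repeats the sketch above):

    extern "C" {
    #include <libavfilter/buffersrc.h>
    #include <libavutil/frame.h>
    }
    #include <memory>
    #include <stdexcept>

    struct AVFrameDeleter {
      void operator()(AVFrame* f) const { av_frame_free(&f); }
    };
    using UniqueAVFrame = std::unique_ptr<AVFrame, AVFrameDeleter>;

    void writeFrameToSource(AVFilterContext* sourceContext, const UniqueAVFrame& avFrame) {
      // Borrow the raw AVFrame* for the C call; `avFrame` keeps ownership and
      // frees the frame when it goes out of scope at the call site.
      int status = av_buffersrc_write_frame(sourceContext, avFrame.get());
      if (status < 0) {
        throw std::runtime_error("Failed to add frame to buffer source context");
      }
    }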
@@ -1354,18 +1348,18 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
 }
 
 void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
-    VideoDecoder::AVFrameStream& avFrameStream,
+    UniqueAVFrame& srcAVFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   TORCH_CHECK(
       !preAllocatedOutputTensor.has_value(),
       "pre-allocated audio tensor not supported yet.");
 
   AVSampleFormat sourceSampleFormat =
-      static_cast<AVSampleFormat>(avFrameStream.avFrame->format);
+      static_cast<AVSampleFormat>(srcAVFrame->format);
   AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
 
-  int sourceSampleRate = avFrameStream.avFrame->sample_rate;
+  int sourceSampleRate = srcAVFrame->sample_rate;
   int desiredSampleRate =
       streamInfos_[activeStreamIndex_].audioStreamOptions.sampleRate.value_or(
           sourceSampleRate);
@@ -1377,14 +1371,13 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
   UniqueAVFrame convertedAVFrame;
   if (mustConvert) {
     convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
-        avFrameStream.avFrame,
+        srcAVFrame,
         sourceSampleFormat,
         desiredSampleFormat,
         sourceSampleRate,
         desiredSampleRate);
   }
-  const UniqueAVFrame& avFrame =
-      mustConvert ? convertedAVFrame : avFrameStream.avFrame;
+  const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
 
   AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
   TORCH_CHECK(
@@ -1981,10 +1974,10 @@ FrameDims getHeightAndWidthFromOptionsOrMetadata(
 
 FrameDims getHeightAndWidthFromOptionsOrAVFrame(
     const VideoDecoder::VideoStreamOptions& videoStreamOptions,
-    const AVFrame& avFrame) {
+    const UniqueAVFrame& avFrame) {
   return FrameDims(
-      videoStreamOptions.height.value_or(avFrame.height),
-      videoStreamOptions.width.value_or(avFrame.width));
+      videoStreamOptions.height.value_or(avFrame->height),
+      videoStreamOptions.width.value_or(avFrame->width));
 }
 
 } // namespace facebook::torchcodec
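Note the accessor change in getHeightAndWidthFromOptionsOrAVFrame: with the parameter now a const UniqueAVFrame& rather than a const AVFrame&, avFrame.height becomes avFrame->height, since std::unique_ptr::operator-> forwards to the managed AVFrame and member access reads the same as through a raw pointer.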