@@ -917,6 +917,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
917
917
(stopPts <= lastDecodedAvFrameEnd);
918
918
}
919
919
920
+ torch::Tensor lastSamples = maybeFlushSwrBuffers ();
921
+ if (lastSamples.numel () > 0 ) {
922
+ frames.push_back (lastSamples);
923
+ }
924
+
920
925
return AudioFramesOutput{torch::cat (frames, 1 ), firstFramePtsSeconds};
921
926
}
922
927
@@ -1349,7 +1354,6 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
1349
1354
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU (
1350
1355
UniqueAVFrame& srcAVFrame,
1351
1356
FrameOutput& frameOutput) {
1352
-
1353
1357
AVSampleFormat sourceSampleFormat =
1354
1358
static_cast <AVSampleFormat>(srcAVFrame->format );
1355
1359
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
@@ -1395,6 +1399,7 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
1395
1399
memcpy (
1396
1400
outputChannelData, avFrame->extended_data [channel], numBytesPerChannel);
1397
1401
}
1402
+
1398
1403
frameOutput.data = outputData;
1399
1404
}
1400
1405
@@ -1449,7 +1454,8 @@ UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate(
1449
1454
streamInfo.swrContext .get (),
1450
1455
convertedAVFrame->data ,
1451
1456
convertedAVFrame->nb_samples ,
1452
- static_cast <const uint8_t **>(const_cast <const uint8_t **>(srcAVFrame->data )),
1457
+ static_cast <const uint8_t **>(
1458
+ const_cast <const uint8_t **>(srcAVFrame->data )),
1453
1459
srcAVFrame->nb_samples );
1454
1460
TORCH_CHECK (
1455
1461
numConvertedSamples > 0 ,
@@ -1463,6 +1469,38 @@ UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate(
1463
1469
return convertedAVFrame;
1464
1470
}
1465
1471
1472
+ torch::Tensor VideoDecoder::maybeFlushSwrBuffers () {
1473
+ // When sample rate conversion is involved, swresample buffers some of the
1474
+ // samples in-between calls to swr_convert (see the libswresample docs).
1475
+ // That's because the last few samples in a given frame require future samples
1476
+ // from the next frame to be properly converted. This function flushes out the
1477
+ // samples that are stored in swresample's buffers.
1478
+ auto & streamInfo = streamInfos_[activeStreamIndex_];
1479
+ if (!streamInfo.swrContext ) {
1480
+ return torch::empty ({0 , 0 });
1481
+ }
1482
+ auto numRemainingSamples = // this is an upper bound
1483
+ swr_get_out_samples (streamInfo.swrContext .get (), 0 );
1484
+
1485
+ if (numRemainingSamples == 0 ) {
1486
+ return torch::empty ({0 , 0 });
1487
+ }
1488
+
1489
+ torch::Tensor lastSamples = torch::empty (
1490
+ {getNumChannels (streamInfo.codecContext ), numRemainingSamples},
1491
+ torch::kFloat32 );
1492
+ uint8_t * lastSamplesData = static_cast <uint8_t *>(lastSamples.data_ptr ());
1493
+
1494
+ auto actualNumRemainingSamples = swr_convert (
1495
+ streamInfo.swrContext .get (),
1496
+ &lastSamplesData,
1497
+ numRemainingSamples,
1498
+ NULL ,
1499
+ 0 );
1500
+ return lastSamples.narrow (
1501
+ /* dim=*/ 1 , /* start=*/ 0 , /* length=*/ actualNumRemainingSamples);
1502
+ }
1503
+
1466
1504
// --------------------------------------------------------------------------
1467
1505
// OUTPUT ALLOCATION AND SHAPE CONVERSION
1468
1506
// --------------------------------------------------------------------------
0 commit comments