Skip to content

Commit f858d0c

Browse files
committed
Add flushing
1 parent 6aa7b09 commit f858d0c

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
917917
(stopPts <= lastDecodedAvFrameEnd);
918918
}
919919

920+
torch::Tensor lastSamples = maybeFlushSwrBuffers();
921+
if (lastSamples.numel() > 0) {
922+
frames.push_back(lastSamples);
923+
}
924+
920925
return AudioFramesOutput{torch::cat(frames, 1), firstFramePtsSeconds};
921926
}
922927

@@ -1349,7 +1354,6 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
13491354
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13501355
UniqueAVFrame& srcAVFrame,
13511356
FrameOutput& frameOutput) {
1352-
13531357
AVSampleFormat sourceSampleFormat =
13541358
static_cast<AVSampleFormat>(srcAVFrame->format);
13551359
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
@@ -1395,6 +1399,7 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13951399
memcpy(
13961400
outputChannelData, avFrame->extended_data[channel], numBytesPerChannel);
13971401
}
1402+
13981403
frameOutput.data = outputData;
13991404
}
14001405

@@ -1449,7 +1454,8 @@ UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate(
14491454
streamInfo.swrContext.get(),
14501455
convertedAVFrame->data,
14511456
convertedAVFrame->nb_samples,
1452-
static_cast<const uint8_t**>(const_cast<const uint8_t**>(srcAVFrame->data)),
1457+
static_cast<const uint8_t**>(
1458+
const_cast<const uint8_t**>(srcAVFrame->data)),
14531459
srcAVFrame->nb_samples);
14541460
TORCH_CHECK(
14551461
numConvertedSamples > 0,
@@ -1463,6 +1469,38 @@ UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate(
14631469
return convertedAVFrame;
14641470
}
14651471

1472+
// Flushes the samples that libswresample is still holding in its internal
// buffers and returns them as a float32 tensor of shape
// (numChannels, numFlushedSamples).
//
// When sample rate conversion is involved, swresample buffers some of the
// samples in-between calls to swr_convert (see the libswresample docs).
// That's because the last few samples in a given frame require future samples
// from the next frame to be properly converted. This function flushes out the
// samples that are stored in swresample's buffers.
//
// Returns an empty (0, 0) tensor when the active stream has no swrContext or
// when swresample reports that nothing is buffered. Raises (via TORCH_CHECK)
// if swresample reports an error.
torch::Tensor VideoDecoder::maybeFlushSwrBuffers() {
  auto& streamInfo = streamInfos_[activeStreamIndex_];
  if (!streamInfo.swrContext) {
    return torch::empty({0, 0});
  }

  const auto numRemainingSamples = // this is an upper bound
      swr_get_out_samples(streamInfo.swrContext.get(), 0);
  TORCH_CHECK(
      numRemainingSamples >= 0,
      "swr_get_out_samples failed while flushing, error code: ",
      numRemainingSamples);
  if (numRemainingSamples == 0) {
    return torch::empty({0, 0});
  }

  const auto numChannels = getNumChannels(streamInfo.codecContext);
  torch::Tensor lastSamples =
      torch::empty({numChannels, numRemainingSamples}, torch::kFloat32);

  // The output is planar (one row of the tensor per channel), so swr_convert
  // needs one destination pointer per channel. Passing the address of a single
  // pointer would make swresample read past it for any multi-channel stream.
  // NOTE(review): assumes the swrContext output format is planar float
  // (AV_SAMPLE_FMT_FLTP), matching convertAudioAVFrameToFrameOutputOnCPU.
  std::vector<uint8_t*> outputBuffers(static_cast<size_t>(numChannels));
  for (int channel = 0; channel < numChannels; ++channel) {
    outputBuffers[channel] =
        static_cast<uint8_t*>(lastSamples[channel].data_ptr());
  }

  // A null input with in_count == 0 asks swresample to drain its buffers.
  const auto actualNumRemainingSamples = swr_convert(
      streamInfo.swrContext.get(),
      outputBuffers.data(),
      numRemainingSamples,
      nullptr,
      0);
  TORCH_CHECK(
      actualNumRemainingSamples >= 0,
      "swr_convert failed while flushing, error code: ",
      actualNumRemainingSamples);

  // numRemainingSamples was only an upper bound: trim to what was written.
  return lastSamples.narrow(
      /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples);
}
1503+
14661504
// --------------------------------------------------------------------------
14671505
// OUTPUT ALLOCATION AND SHAPE CONVERSION
14681506
// --------------------------------------------------------------------------

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,8 @@ class VideoDecoder {
407407
int sourceSampleRate,
408408
int desiredSampleRate);
409409

410+
torch::Tensor maybeFlushSwrBuffers();
411+
410412
// --------------------------------------------------------------------------
411413
// COLOR CONVERSION LIBRARIES HANDLERS CREATION
412414
// --------------------------------------------------------------------------

0 commit comments

Comments
 (0)