Skip to content

Commit db740a6

Browse files
Committed: Merge branch 'main' of github.com:pytorch/torchcodec into sample_rate
Merge commit db740a6 (2 parents: ef93be4 + 23c73ea)

File tree

8 files changed

+62
-84
lines changed

8 files changed

+62
-84
lines changed

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace facebook::torchcodec {
1717
void convertAVFrameToFrameOutputOnCuda(
1818
const torch::Device& device,
1919
[[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions,
20-
[[maybe_unused]] VideoDecoder::AVFrameStream& avFrameStream,
20+
[[maybe_unused]] UniqueAVFrame& avFrame,
2121
[[maybe_unused]] VideoDecoder::FrameOutput& frameOutput,
2222
[[maybe_unused]] std::optional<torch::Tensor> preAllocatedOutputTensor) {
2323
throwUnsupportedDeviceError(device);

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,15 @@ void initializeContextOnCuda(
190190
void convertAVFrameToFrameOutputOnCuda(
191191
const torch::Device& device,
192192
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
193-
VideoDecoder::AVFrameStream& avFrameStream,
193+
UniqueAVFrame& avFrame,
194194
VideoDecoder::FrameOutput& frameOutput,
195195
std::optional<torch::Tensor> preAllocatedOutputTensor) {
196-
AVFrame* avFrame = avFrameStream.avFrame.get();
197-
198196
TORCH_CHECK(
199197
avFrame->format == AV_PIX_FMT_CUDA,
200198
"Expected format to be AV_PIX_FMT_CUDA, got " +
201199
std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format)));
202200
auto frameDims =
203-
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, *avFrame);
201+
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
204202
int height = frameDims.height;
205203
int width = frameDims.width;
206204
torch::Tensor& dst = frameOutput.data;

src/torchcodec/decoders/_core/DeviceInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ void initializeContextOnCuda(
3232
void convertAVFrameToFrameOutputOnCuda(
3333
const torch::Device& device,
3434
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
35-
VideoDecoder::AVFrameStream& avFrameStream,
35+
UniqueAVFrame& avFrame,
3636
VideoDecoder::FrameOutput& frameOutput,
3737
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
3838

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,11 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode) {
4848
return std::string(errorBuffer);
4949
}
5050

51-
int64_t getDuration(const UniqueAVFrame& frame) {
52-
return getDuration(frame.get());
53-
}
54-
55-
int64_t getDuration(const AVFrame* frame) {
51+
int64_t getDuration(const UniqueAVFrame& avFrame) {
5652
#if LIBAVUTIL_VERSION_MAJOR < 58
57-
return frame->pkt_duration;
53+
return avFrame->pkt_duration;
5854
#else
59-
return frame->duration;
55+
return avFrame->duration;
6056
#endif
6157
}
6258

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode);
140140
// struct member representing duration has changed across the versions we
141141
// support.
142142
int64_t getDuration(const UniqueAVFrame& frame);
143-
int64_t getDuration(const AVFrame* frame);
144143

145144
int getNumChannels(const UniqueAVFrame& avFrame);
146145
int getNumChannels(const UniqueAVCodecContext& avCodecContext);

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -587,9 +587,9 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
587587
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
588588
std::optional<torch::Tensor> preAllocatedOutputTensor) {
589589
validateActiveStream();
590-
AVFrameStream avFrameStream = decodeAVFrame(
591-
[this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
592-
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
590+
UniqueAVFrame avFrame = decodeAVFrame(
591+
[this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
592+
return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
593593
}
594594

595595
VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
@@ -719,8 +719,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
719719
}
720720

721721
setCursorPtsInSeconds(seconds);
722-
AVFrameStream avFrameStream =
723-
decodeAVFrame([seconds, this](AVFrame* avFrame) {
722+
UniqueAVFrame avFrame =
723+
decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) {
724724
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
725725
double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
726726
double frameEndTime = ptsToSeconds(
@@ -739,7 +739,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
739739
});
740740

741741
// Convert the frame to tensor.
742-
FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrameStream);
742+
FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrame);
743743
frameOutput.data = maybePermuteHWC2CHW(frameOutput.data);
744744
return frameOutput;
745745
}
@@ -895,14 +895,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
895895
auto finished = false;
896896
while (!finished) {
897897
try {
898-
AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) {
899-
return startPts < avFrame->pts + getDuration(avFrame);
900-
});
901-
// TODO: it's not great that we are getting a FrameOutput, which is
902-
// intended for videos. We should consider bypassing
903-
// convertAVFrameToFrameOutput and directly call
904-
// convertAudioAVFrameToFrameOutputOnCPU.
905-
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
898+
UniqueAVFrame avFrame =
899+
decodeAVFrame([startPts](const UniqueAVFrame& avFrame) {
900+
return startPts < avFrame->pts + getDuration(avFrame);
901+
});
902+
auto frameOutput = convertAVFrameToFrameOutput(avFrame);
906903
firstFramePtsSeconds =
907904
std::min(firstFramePtsSeconds, frameOutput.ptsSeconds);
908905
frames.push_back(frameOutput.data);
@@ -1039,8 +1036,8 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
10391036
// LOW-LEVEL DECODING
10401037
// --------------------------------------------------------------------------
10411038

1042-
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
1043-
std::function<bool(AVFrame*)> filterFunction) {
1039+
UniqueAVFrame VideoDecoder::decodeAVFrame(
1040+
std::function<bool(const UniqueAVFrame&)> filterFunction) {
10441041
validateActiveStream();
10451042

10461043
resetDecodeStats();
@@ -1068,7 +1065,7 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
10681065

10691066
decodeStats_.numFramesReceivedByDecoder++;
10701067
// Is this the kind of frame we're looking for?
1071-
if (status == AVSUCCESS && filterFunction(avFrame.get())) {
1068+
if (status == AVSUCCESS && filterFunction(avFrame)) {
10721069
// Yes, this is the frame we'll return; break out of the decoding loop.
10731070
break;
10741071
} else if (status == AVSUCCESS) {
@@ -1154,37 +1151,35 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
11541151
streamInfo.lastDecodedAvFramePts = avFrame->pts;
11551152
streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame);
11561153

1157-
return AVFrameStream(std::move(avFrame), activeStreamIndex_);
1154+
return avFrame;
11581155
}
11591156

11601157
// --------------------------------------------------------------------------
11611158
// AVFRAME <-> FRAME OUTPUT CONVERSION
11621159
// --------------------------------------------------------------------------
11631160

11641161
VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
1165-
VideoDecoder::AVFrameStream& avFrameStream,
1162+
UniqueAVFrame& avFrame,
11661163
std::optional<torch::Tensor> preAllocatedOutputTensor) {
11671164
// Convert the frame to tensor.
11681165
FrameOutput frameOutput;
1169-
int streamIndex = avFrameStream.streamIndex;
1170-
AVFrame* avFrame = avFrameStream.avFrame.get();
1171-
frameOutput.streamIndex = streamIndex;
1172-
auto& streamInfo = streamInfos_[streamIndex];
1166+
auto& streamInfo = streamInfos_[activeStreamIndex_];
11731167
frameOutput.ptsSeconds = ptsToSeconds(
1174-
avFrame->pts, formatContext_->streams[streamIndex]->time_base);
1168+
avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base);
11751169
frameOutput.durationSeconds = ptsToSeconds(
1176-
getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
1170+
getDuration(avFrame),
1171+
formatContext_->streams[activeStreamIndex_]->time_base);
11771172
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
11781173
convertAudioAVFrameToFrameOutputOnCPU(
1179-
avFrameStream, frameOutput, preAllocatedOutputTensor);
1174+
avFrame, frameOutput, preAllocatedOutputTensor);
11801175
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
11811176
convertAVFrameToFrameOutputOnCPU(
1182-
avFrameStream, frameOutput, preAllocatedOutputTensor);
1177+
avFrame, frameOutput, preAllocatedOutputTensor);
11831178
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) {
11841179
convertAVFrameToFrameOutputOnCuda(
11851180
streamInfo.videoStreamOptions.device,
11861181
streamInfo.videoStreamOptions,
1187-
avFrameStream,
1182+
avFrame,
11881183
frameOutput,
11891184
preAllocatedOutputTensor);
11901185
} else {
@@ -1205,14 +1200,13 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
12051200
// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
12061201
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
12071202
void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
1208-
VideoDecoder::AVFrameStream& avFrameStream,
1203+
UniqueAVFrame& avFrame,
12091204
FrameOutput& frameOutput,
12101205
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1211-
AVFrame* avFrame = avFrameStream.avFrame.get();
12121206
auto& streamInfo = streamInfos_[activeStreamIndex_];
12131207

12141208
auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
1215-
streamInfo.videoStreamOptions, *avFrame);
1209+
streamInfo.videoStreamOptions, avFrame);
12161210
int expectedOutputHeight = frameDims.height;
12171211
int expectedOutputWidth = frameDims.width;
12181212

@@ -1306,7 +1300,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
13061300
}
13071301

13081302
int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
1309-
const AVFrame* avFrame,
1303+
const UniqueAVFrame& avFrame,
13101304
torch::Tensor& outputTensor) {
13111305
StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
13121306
SwsContext* swsContext = activeStreamInfo.swsContext.get();
@@ -1326,11 +1320,11 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
13261320
}
13271321

13281322
torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
1329-
const AVFrame* avFrame) {
1323+
const UniqueAVFrame& avFrame) {
13301324
FilterGraphContext& filterGraphContext =
13311325
streamInfos_[activeStreamIndex_].filterGraphContext;
13321326
int status =
1333-
av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame);
1327+
av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame.get());
13341328
if (status < AVSUCCESS) {
13351329
throw std::runtime_error("Failed to add frame to buffer source context");
13361330
}
@@ -1354,18 +1348,18 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
13541348
}
13551349

13561350
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
1357-
VideoDecoder::AVFrameStream& avFrameStream,
1351+
UniqueAVFrame& srcAVFrame,
13581352
FrameOutput& frameOutput,
13591353
std::optional<torch::Tensor> preAllocatedOutputTensor) {
13601354
TORCH_CHECK(
13611355
!preAllocatedOutputTensor.has_value(),
13621356
"pre-allocated audio tensor not supported yet.");
13631357

13641358
AVSampleFormat sourceSampleFormat =
1365-
static_cast<AVSampleFormat>(avFrameStream.avFrame->format);
1359+
static_cast<AVSampleFormat>(srcAVFrame->format);
13661360
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
13671361

1368-
int sourceSampleRate = avFrameStream.avFrame->sample_rate;
1362+
int sourceSampleRate = srcAVFrame->sample_rate;
13691363
int desiredSampleRate =
13701364
streamInfos_[activeStreamIndex_].audioStreamOptions.sampleRate.value_or(
13711365
sourceSampleRate);
@@ -1377,14 +1371,13 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
13771371
UniqueAVFrame convertedAVFrame;
13781372
if (mustConvert) {
13791373
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
1380-
avFrameStream.avFrame,
1374+
srcAVFrame,
13811375
sourceSampleFormat,
13821376
desiredSampleFormat,
13831377
sourceSampleRate,
13841378
desiredSampleRate);
13851379
}
1386-
const UniqueAVFrame& avFrame =
1387-
mustConvert ? convertedAVFrame : avFrameStream.avFrame;
1380+
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
13881381

13891382
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
13901383
TORCH_CHECK(
@@ -1981,10 +1974,10 @@ FrameDims getHeightAndWidthFromOptionsOrMetadata(
19811974

19821975
FrameDims getHeightAndWidthFromOptionsOrAVFrame(
19831976
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
1984-
const AVFrame& avFrame) {
1977+
const UniqueAVFrame& avFrame) {
19851978
return FrameDims(
1986-
videoStreamOptions.height.value_or(avFrame.height),
1987-
videoStreamOptions.width.value_or(avFrame.width));
1979+
videoStreamOptions.height.value_or(avFrame->height),
1980+
videoStreamOptions.width.value_or(avFrame->width));
19881981
}
19891982

19901983
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,13 @@ class VideoDecoder {
161161
// They are the equivalent of the user-facing Frame and FrameBatch classes in
162162
// Python. They contain RGB decoded frames along with some associated data
163163
// like PTS and duration.
164+
// FrameOutput is also relevant for audio decoding, typically as the output of
165+
// getNextFrame(), or as a temporary output variable.
164166
struct FrameOutput {
165-
torch::Tensor data; // 3D: of shape CHW or HWC.
166-
int streamIndex;
167+
// data shape is:
168+
// - 3D (C, H, W) or (H, W, C) for videos
169+
// - 2D (numChannels, numSamples) for audio
170+
torch::Tensor data;
167171
double ptsSeconds;
168172
double durationSeconds;
169173
};
@@ -252,23 +256,6 @@ class VideoDecoder {
252256
// These are APIs that should be private, but that are effectively exposed for
253257
// practical reasons, typically for testing purposes.
254258

255-
// This struct is needed because AVFrame doesn't retain the streamIndex. Only
256-
// the AVPacket knows its stream. This is what the low-level private decoding
257-
// entry points return. The AVFrameStream is then converted to a FrameOutput
258-
// with convertAVFrameToFrameOutput. It should be private, but is currently
259-
// used by DeviceInterface.
260-
struct AVFrameStream {
261-
// The actual decoded output as a unique pointer to an AVFrame.
262-
// Usually, this is a YUV frame. It'll be converted to RGB in
263-
// convertAVFrameToFrameOutput.
264-
UniqueAVFrame avFrame;
265-
// The stream index of the decoded frame.
266-
int streamIndex;
267-
268-
explicit AVFrameStream(UniqueAVFrame&& a, int s)
269-
: avFrame(std::move(a)), streamIndex(s) {}
270-
};
271-
272259
// Once getFrameAtIndex supports the preAllocatedOutputTensor parameter, we
273260
// can move it back to private.
274261
FrameOutput getFrameAtIndexInternal(
@@ -385,31 +372,33 @@ class VideoDecoder {
385372

386373
void maybeSeekToBeforeDesiredPts();
387374

388-
AVFrameStream decodeAVFrame(std::function<bool(AVFrame*)> filterFunction);
375+
UniqueAVFrame decodeAVFrame(
376+
std::function<bool(const UniqueAVFrame&)> filterFunction);
389377

390378
FrameOutput getNextFrameInternal(
391379
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
392380

393381
torch::Tensor maybePermuteHWC2CHW(torch::Tensor& hwcTensor);
394382

395383
FrameOutput convertAVFrameToFrameOutput(
396-
AVFrameStream& avFrameStream,
384+
UniqueAVFrame& avFrame,
397385
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
398386

399387
void convertAVFrameToFrameOutputOnCPU(
400-
AVFrameStream& avFrameStream,
388+
UniqueAVFrame& avFrame,
401389
FrameOutput& frameOutput,
402390
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
403391

404392
void convertAudioAVFrameToFrameOutputOnCPU(
405-
AVFrameStream& avFrameStream,
393+
UniqueAVFrame& srcAVFrame,
406394
FrameOutput& frameOutput,
407395
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
408396

409-
torch::Tensor convertAVFrameToTensorUsingFilterGraph(const AVFrame* avFrame);
397+
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
398+
const UniqueAVFrame& avFrame);
410399

411400
int convertAVFrameToTensorUsingSwsScale(
412-
const AVFrame* avFrame,
401+
const UniqueAVFrame& avFrame,
413402
torch::Tensor& outputTensor);
414403

415404
UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
@@ -580,7 +569,7 @@ FrameDims getHeightAndWidthFromOptionsOrMetadata(
580569

581570
FrameDims getHeightAndWidthFromOptionsOrAVFrame(
582571
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
583-
const AVFrame& avFrame);
572+
const UniqueAVFrame& avFrame);
584573

585574
torch::Tensor allocateEmptyHWCTensor(
586575
int height,

test/decoders/test_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,12 +650,15 @@ def test_next(self, asset):
650650
frame_index = 0
651651
while True:
652652
try:
653-
frame, *_ = get_next_frame(decoder)
653+
frame, pts_seconds, duration_seconds = get_next_frame(decoder)
654654
except IndexError:
655655
break
656656
torch.testing.assert_close(
657657
frame, asset.get_frame_data_by_index(frame_index)
658658
)
659+
frame_info = asset.get_frame_info(frame_index)
660+
assert pts_seconds == frame_info.pts_seconds
661+
assert duration_seconds == frame_info.duration_seconds
659662
frame_index += 1
660663

661664
@pytest.mark.parametrize(

0 commit comments

Comments (0)