Skip to content

Commit 298c9c1

Browse files
authored
Rough implementation of getting key frame indices (#484)
1 parent f46f64a commit 298c9c1

File tree

8 files changed

+113
-5
lines changed

8 files changed

+113
-5
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,20 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const {
538538
return containerMetadata_;
539539
}
540540

541+
// Returns the key frame indices of the stream `streamIndex` as a 1D int64
// tensor. Element k is the frameIndex stored in the k-th entry of the stream's
// keyFrames vector, i.e. the position of that key frame among all frames of
// the stream.
//
// The stream index is checked with validateUserProvidedStreamIndex(), and
// validateScannedAllStreams() is called first because frameIndex is only
// filled in while scanning the file.
torch::Tensor VideoDecoder::getKeyFrameIndices(int streamIndex) {
  validateUserProvidedStreamIndex(streamIndex);
  validateScannedAllStreams("getKeyFrameIndices");

  const std::vector<FrameInfo>& keyFrames = streamInfos_[streamIndex].keyFrames;
  torch::Tensor keyFrameIndices =
      torch::empty({static_cast<int64_t>(keyFrames.size())}, {torch::kInt64});

  // The tensor was just allocated above as a dense 1D int64 tensor, so we can
  // fill its storage directly. Assigning through keyFrameIndices[i] would
  // instead run a tensor select + fill operation for every single key frame.
  // When there are no key frames the loop body never runs, so the pointer is
  // never dereferenced.
  int64_t* indices = keyFrameIndices.data_ptr<int64_t>();
  for (size_t i = 0; i < keyFrames.size(); ++i) {
    indices[i] = keyFrames[i].frameIndex;
  }

  return keyFrameIndices;
}
554+
541555
int VideoDecoder::getKeyFrameIndexForPtsUsingEncoderIndex(
542556
AVStream* stream,
543557
int64_t pts) const {
@@ -654,7 +668,21 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
654668
return frameInfo1.pts < frameInfo2.pts;
655669
});
656670

671+
size_t keyIndex = 0;
657672
for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) {
673+
streamInfo.allFrames[i].frameIndex = i;
674+
675+
// For correctly encoded files, we shouldn't need to ensure that keyIndex
676+
// is less than the number of key frames. That is, the relationship
677+
// between the frames in allFrames and keyFrames should be such that
678+
// keyIndex is always a valid index into keyFrames. But we're being
679+
// defensive in case we encounter incorrectly encoded files.
680+
if (keyIndex < streamInfo.keyFrames.size() &&
681+
streamInfo.keyFrames[keyIndex].pts == streamInfo.allFrames[i].pts) {
682+
streamInfo.keyFrames[keyIndex].frameIndex = i;
683+
++keyIndex;
684+
}
685+
658686
if (i + 1 < streamInfo.allFrames.size()) {
659687
streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts;
660688
}

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ class VideoDecoder {
9797
// Returns the metadata for the container.
9898
ContainerMetadata getContainerMetadata() const;
9999

100+
// Returns the key frame indices as a tensor. The tensor is 1D and contains
101+
// int64 values, where each value is the frame index for a key frame.
102+
torch::Tensor getKeyFrameIndices(int streamIndex);
103+
100104
// --------------------------------------------------------------------------
101105
// ADDING STREAMS API
102106
// --------------------------------------------------------------------------
@@ -284,12 +288,19 @@ class VideoDecoder {
284288

285289
struct FrameInfo {
286290
int64_t pts = 0;
287-
// The value of this default is important: the last frame's nextPts will be
288-
// INT64_MAX, which ensures that the allFrames vec contains FrameInfo
289-
// structs with *increasing* nextPts values. That's a necessary condition
290-
// for the binary searches on those values to work properly (as typically
291-
// done during pts -> index conversions.)
291+
292+
// The value of the nextPts default is important: the last frame's nextPts
293+
// will be INT64_MAX, which ensures that the allFrames vec contains
294+
// FrameInfo structs with *increasing* nextPts values. That's a necessary
295+
// condition for the binary searches on those values to work properly (as
296+
// typically done during pts -> index conversions).
292297
int64_t nextPts = INT64_MAX;
298+
299+
// Note that frameIndex is ALWAYS the index into all of the frames in that
300+
// stream, even when the FrameInfo is part of the key frame index. Given a
301+
// FrameInfo for a key frame, the frameIndex allows us to know which frame
302+
// that is in the stream.
303+
int64_t frameIndex = 0;
293304
};
294305

295306
struct FilterGraphContext {

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4848
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
4949
m.def(
5050
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)");
51+
m.def(
52+
"_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> Tensor");
5153
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
5254
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
5355
m.def(
@@ -334,6 +336,13 @@ bool _test_frame_pts_equality(
334336
videoDecoder->getPtsSecondsForFrame(stream_index, frame_index);
335337
}
336338

339+
// Op entry point for `_get_key_frame_indices`: unwraps the decoder stored in
// the `decoder` tensor and forwards to VideoDecoder::getKeyFrameIndices() for
// the requested stream.
torch::Tensor _get_key_frame_indices(
    at::Tensor& decoder,
    int64_t stream_index) {
  auto decoderHandle = unwrapTensorToGetDecoder(decoder);
  // The op receives an int64_t while getKeyFrameIndices() takes an int; spell
  // out the conversion that would otherwise happen implicitly.
  const int streamIndex = static_cast<int>(stream_index);
  return decoderHandle->getKeyFrameIndices(streamIndex);
}
345+
337346
std::string get_json_metadata(at::Tensor& decoder) {
338347
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
339348

@@ -526,6 +535,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
526535
m.impl("add_video_stream", &add_video_stream);
527536
m.impl("_add_video_stream", &_add_video_stream);
528537
m.impl("get_next_frame", &get_next_frame);
538+
m.impl("_get_key_frame_indices", &_get_key_frame_indices);
529539
m.impl("get_json_metadata", &get_json_metadata);
530540
m.impl("get_container_json_metadata", &get_container_json_metadata);
531541
m.impl("get_stream_json_metadata", &get_stream_json_metadata);

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ bool _test_frame_pts_equality(
137137
int64_t frame_index,
138138
double pts_seconds_to_test);
139139

140+
// Get the key frame indices of the stream `stream_index` as a tensor; each
// value is the frame index of a key frame within that stream.
torch::Tensor _get_key_frame_indices(at::Tensor& decoder, int64_t stream_index);
141+
140142
// Get the metadata from the video as a string.
141143
std::string get_json_metadata(at::Tensor& decoder);
142144

src/torchcodec/decoders/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from .video_decoder_ops import (
1515
_add_video_stream,
16+
_get_key_frame_indices,
1617
_test_frame_pts_equality,
1718
add_video_stream,
1819
create_from_bytes,

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def load_torchcodec_extension():
8282
_get_container_json_metadata = (
8383
torch.ops.torchcodec_ns.get_container_json_metadata.default
8484
)
85+
_get_key_frame_indices = torch.ops.torchcodec_ns._get_key_frame_indices.default
8586
scan_all_streams_to_update_metadata = (
8687
torch.ops.torchcodec_ns.scan_all_streams_to_update_metadata.default
8788
)
@@ -255,6 +256,13 @@ def get_frames_by_pts_in_range_abstract(
255256
)
256257

257258

259+
@register_fake("torchcodec_ns::_get_key_frame_indices")
def get_key_frame_indices_abstract(
    decoder: torch.Tensor, stream_index: int
) -> torch.Tensor:
    """Fake (shape/dtype-only) implementation of ``_get_key_frame_indices``.

    Args:
        decoder: Tensor wrapping the decoder, as passed to the real op.
        stream_index: Index of the stream to query. It is a regular
            (positional-or-keyword) parameter because the op schema declares
            it as ``int stream_index`` without a keyword-only marker, so the
            op may be called either way.

    Returns:
        A 1-D ``int64`` tensor, matching what the real op produces. Its length
        is an unbacked dynamic size because the number of key frames is not
        known without reading the file.
    """
    num_key_frames = torch.library.get_ctx().new_dynamic_size()
    return torch.empty([num_key_frames], dtype=torch.int64)
264+
265+
258266
@register_fake("torchcodec_ns::get_json_metadata")
259267
def get_json_metadata_abstract(decoder: torch.Tensor) -> str:
260268
return ""

src/torchcodec/decoders/_video_decoder.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor:
185185
f"Unsupported key type: {type(key)}. Supported types are int and slice."
186186
)
187187

188+
def _get_key_frame_indices(self) -> Tensor:
    """Return the key frame indices of this decoder's stream.

    Returns:
        Tensor: the tensor produced by the core ``_get_key_frame_indices`` op
        for ``self.stream_index``; each value is the index of a key frame.
    """
    return core._get_key_frame_indices(
        self._decoder, stream_index=self.stream_index
    )
192+
188193
def get_frame_at(self, index: int) -> Frame:
189194
"""Return a single frame at the given index.
190195

test/decoders/test_video_decoder.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,49 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode):
831831
with pytest.raises(ValueError, match="Invalid stop seconds"):
832832
frame = decoder.get_frames_played_in_range(0, 23) # noqa
833833

834+
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_key_frame_indices(self, device):
    """Check the key frame indices reported for several reference videos."""
    # The expected indices were generated with ffprobe, e.g. for the NASA video:
    #
    #   $ ffprobe -v error -hide_banner -select_streams v:1 -show_frames -of csv test/resources/nasa_13013.mp4 | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt
    #
    # What the command does:
    #   1. Calls ffprobe on the second video stream, which is absolute stream
    #      index 3.
    #   2. Shows all frames for that stream.
    #   3. Uses grep to find the "I" frames, which are the key frames. grep -n
    #      also prints the line number, which is the count of the frames.
    #   4. Uses cut to extract just the count for the frame.
    # Because the above produces a count, which is index + 1, one was manually
    # subtracted from every value to arrive at the indices below.
    #
    # The other two videos use the same pipeline on their first video stream:
    #
    #   $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of csv test/resources/av1_video.mkv | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt
    #   $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of csv test/resources/h265_video.mp4 | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt
    #
    # TODO: decide if/how we want to incorporate key frame indices into the
    # utils framework.
    reference_cases = (
        (NASA_VIDEO, [0, 240]),
        (AV1_VIDEO, [0]),
        (H265_VIDEO, [0, 2, 4, 6, 8]),
    )

    for asset, expected_indices in reference_cases:
        decoder = VideoDecoder(asset.path, device=device, seek_mode="exact")
        actual = decoder._get_key_frame_indices()
        expected = torch.tensor(expected_indices)

        # atol=0 / rtol=0: the indices must match exactly.
        torch.testing.assert_close(actual, expected, atol=0, rtol=0)
876+
834877

835878
if __name__ == "__main__":
836879
pytest.main()

0 commit comments

Comments
 (0)