Skip to content

Commit c0bb49b

Browse files
shapovalovfacebook-github-bot
authored andcommitted
API for accessing frames in order in Implicitron dataset.
Summary: We often want to iterate over frames in the sequence in temporal order. This diff provides the API to do that. `seq_to_idx` should probably be considered to have `protected` visibility. Reviewed By: davnov134 Differential Revision: D35012121 fbshipit-source-id: 41896672ec35cd62f3ed4be3aa119efd33adada1
1 parent 05f656c commit c0bb49b

File tree

3 files changed

+56
-37
lines changed

3 files changed

+56
-37
lines changed

pytorch3d/implicitron/dataset/implicitron_dataset.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from typing import (
1919
ClassVar,
2020
Dict,
21+
Iterable,
22+
Iterator,
2123
List,
2224
Optional,
2325
Sequence,
@@ -203,11 +205,11 @@ class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
203205
204206
This means they have a __getitem__ which returns an instance of a FrameData,
205207
which will describe one frame in one sequence.
206-
207-
Members:
208-
seq_to_idx: For each sequence, the indices of its frames.
209208
"""
210209

210+
# Maps sequence name to the sequence's global frame indices.
211+
# It is used for the default implementations of some functions in this class.
212+
# Implementations which override them are free to ignore this member.
211213
seq_to_idx: Dict[str, List[int]] = field(init=False)
212214

213215
def __len__(self) -> int:
@@ -240,6 +242,43 @@ def get_frame_numbers_and_timestamps(
240242
def get_eval_batches(self) -> Optional[List[List[int]]]:
241243
return None
242244

245+
def sequence_names(self) -> Iterable[str]:
246+
"""Returns an iterator over sequence names in the dataset."""
247+
return self.seq_to_idx.keys()
248+
249+
def sequence_frames_in_order(
250+
self, seq_name: str
251+
) -> Iterator[Tuple[float, int, int]]:
252+
"""Returns an iterator over the frame indices in a given sequence.
253+
We attempt to first sort by timestamp (if they are available),
254+
then by frame number.
255+
256+
Args:
257+
seq_name: the name of the sequence.
258+
259+
Returns:
260+
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
261+
where `frame_no` is the index within the sequence, and
262+
`dataset_idx` is the index within the dataset.
263+
`None` timestamps are replaced with 0s.
264+
"""
265+
seq_frame_indices = self.seq_to_idx[seq_name]
266+
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
267+
268+
yield from sorted(
269+
[
270+
(timestamp, frame_no, idx)
271+
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
272+
]
273+
)
274+
275+
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
276+
"""Same as `sequence_frames_in_order` but returns the iterator over
277+
only dataset indices.
278+
"""
279+
for _, _, idx in self.sequence_frames_in_order(seq_name):
280+
yield idx
281+
243282

244283
class FrameAnnotsEntry(TypedDict):
245284
subset: Optional[str]

pytorch3d/implicitron/dataset/scene_batch_sampler.py

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import warnings
99
from dataclasses import dataclass, field
10-
from typing import Iterator, List, Sequence, Tuple
10+
from typing import Iterable, Iterator, List, Sequence, Tuple
1111

1212
import numpy as np
1313
from torch.utils.data.sampler import Sampler
@@ -54,7 +54,7 @@ def __post_init__(self) -> None:
5454
if len(self.images_per_seq_options) < 1:
5555
raise ValueError("n_per_seq_posibilities list cannot be empty")
5656

57-
self.seq_names = list(self.dataset.seq_to_idx.keys())
57+
self.seq_names = list(self.dataset.sequence_names())
5858

5959
def __len__(self) -> int:
6060
return self.num_batches
@@ -72,9 +72,7 @@ def _sample_batch(self, batch_idx) -> List[int]:
7272
if self.sample_consecutive_frames:
7373
frame_idx = []
7474
for seq in chosen_seq:
75-
segment_index = self._build_segment_index(
76-
list(self.dataset.seq_to_idx[seq]), n_per_seq
77-
)
75+
segment_index = self._build_segment_index(seq, n_per_seq)
7876

7977
segment, idx = segment_index[np.random.randint(len(segment_index))]
8078
if len(segment) <= n_per_seq:
@@ -86,7 +84,9 @@ def _sample_batch(self, batch_idx) -> List[int]:
8684
else:
8785
frame_idx = [
8886
_capped_random_choice(
89-
self.dataset.seq_to_idx[seq], n_per_seq, replace=False
87+
list(self.dataset.sequence_indices_in_order(seq)),
88+
n_per_seq,
89+
replace=False,
9090
)
9191
for seq in chosen_seq
9292
]
@@ -98,9 +98,7 @@ def _sample_batch(self, batch_idx) -> List[int]:
9898
)
9999
return frame_idx
100100

101-
def _build_segment_index(
102-
self, seq_frame_indices: List[int], size: int
103-
) -> List[Tuple[List[int], int]]:
101+
def _build_segment_index(self, seq: str, size: int) -> List[Tuple[List[int], int]]:
104102
"""
105103
Returns a list of (segment, index) tuples, one per eligible frame, where
106104
segment is a list of frame indices in the contiguous segment the frame
@@ -111,16 +109,14 @@ def _build_segment_index(
111109
self.consecutive_frames_max_gap > 0
112110
or self.consecutive_frames_max_gap_seconds > 0.0
113111
):
114-
sequence_timestamps = _sort_frames_by_timestamps_then_numbers(
115-
seq_frame_indices, self.dataset
112+
segments = self._split_to_segments(
113+
self.dataset.sequence_frames_in_order(seq)
116114
)
117-
# TODO: use new API to access frame numbers / timestamps
118-
segments = self._split_to_segments(sequence_timestamps)
119115
segments = _cull_short_segments(segments, size)
120116
if not segments:
121117
raise AssertionError("Empty segments after culling")
122118
else:
123-
segments = [seq_frame_indices]
119+
segments = [list(self.dataset.sequence_indices_in_order(seq))]
124120

125121
# build an index of segment for random selection of a pivot frame
126122
segment_index = [
@@ -130,7 +126,7 @@ def _build_segment_index(
130126
return segment_index
131127

132128
def _split_to_segments(
133-
self, sequence_timestamps: List[Tuple[float, int, int]]
129+
self, sequence_timestamps: Iterable[Tuple[float, int, int]]
134130
) -> List[List[int]]:
135131
if (
136132
self.consecutive_frames_max_gap <= 0
@@ -144,7 +140,7 @@ def _split_to_segments(
144140
for ts, no, idx in sequence_timestamps:
145141
if ts <= 0.0 and no <= last_no:
146142
raise AssertionError(
147-
"Frames are not ordered in seq_to_idx while timestamps are not given"
143+
"Sequence frames are not ordered while timestamps are not given"
148144
)
149145

150146
if (
@@ -161,23 +157,6 @@ def _split_to_segments(
161157
return segments
162158

163159

164-
def _sort_frames_by_timestamps_then_numbers(
165-
seq_frame_indices: List[int], dataset: ImplicitronDatasetBase
166-
) -> List[Tuple[float, int, int]]:
167-
"""Build the list of triplets (timestamp, frame_no, dataset_idx).
168-
We attempt to first sort by timestamp, then by frame number.
169-
Timestamps are coalesced with 0s.
170-
"""
171-
nos_timestamps = dataset.get_frame_numbers_and_timestamps(seq_frame_indices)
172-
173-
return sorted(
174-
[
175-
(timestamp, frame_no, idx)
176-
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
177-
]
178-
)
179-
180-
181160
def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]:
182161
lengths = [(len(segment), segment) for segment in segments]
183162
max_len, longest_segment = max(lengths)

tests/implicitron/test_batch_sampler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from collections import defaultdict
1010
from dataclasses import dataclass
1111

12+
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDatasetBase
1213
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
1314

1415

@@ -18,7 +19,7 @@ class MockFrameAnnotation:
1819
frame_timestamp: float = 0.0
1920

2021

21-
class MockDataset:
22+
class MockDataset(ImplicitronDatasetBase):
2223
def __init__(self, num_seq, max_frame_gap=1):
2324
"""
2425
Makes a gap of max_frame_gap frame numbers in the middle of each sequence

0 commit comments

Comments
 (0)