7
7
8
8
import warnings
9
9
from dataclasses import dataclass , field
10
- from typing import Iterator , List , Sequence , Tuple
10
+ from typing import Iterable , Iterator , List , Sequence , Tuple
11
11
12
12
import numpy as np
13
13
from torch .utils .data .sampler import Sampler
@@ -54,7 +54,7 @@ def __post_init__(self) -> None:
54
54
if len (self .images_per_seq_options ) < 1 :
55
55
raise ValueError ("n_per_seq_posibilities list cannot be empty" )
56
56
57
- self .seq_names = list (self .dataset .seq_to_idx . keys ())
57
+ self .seq_names = list (self .dataset .sequence_names ())
58
58
59
59
def __len__ (self ) -> int :
60
60
return self .num_batches
@@ -72,9 +72,7 @@ def _sample_batch(self, batch_idx) -> List[int]:
72
72
if self .sample_consecutive_frames :
73
73
frame_idx = []
74
74
for seq in chosen_seq :
75
- segment_index = self ._build_segment_index (
76
- list (self .dataset .seq_to_idx [seq ]), n_per_seq
77
- )
75
+ segment_index = self ._build_segment_index (seq , n_per_seq )
78
76
79
77
segment , idx = segment_index [np .random .randint (len (segment_index ))]
80
78
if len (segment ) <= n_per_seq :
@@ -86,7 +84,9 @@ def _sample_batch(self, batch_idx) -> List[int]:
86
84
else :
87
85
frame_idx = [
88
86
_capped_random_choice (
89
- self .dataset .seq_to_idx [seq ], n_per_seq , replace = False
87
+ list (self .dataset .sequence_indices_in_order (seq )),
88
+ n_per_seq ,
89
+ replace = False ,
90
90
)
91
91
for seq in chosen_seq
92
92
]
@@ -98,9 +98,7 @@ def _sample_batch(self, batch_idx) -> List[int]:
98
98
)
99
99
return frame_idx
100
100
101
- def _build_segment_index (
102
- self , seq_frame_indices : List [int ], size : int
103
- ) -> List [Tuple [List [int ], int ]]:
101
+ def _build_segment_index (self , seq : str , size : int ) -> List [Tuple [List [int ], int ]]:
104
102
"""
105
103
Returns a list of (segment, index) tuples, one per eligible frame, where
106
104
segment is a list of frame indices in the contiguous segment the frame
@@ -111,16 +109,14 @@ def _build_segment_index(
111
109
self .consecutive_frames_max_gap > 0
112
110
or self .consecutive_frames_max_gap_seconds > 0.0
113
111
):
114
- sequence_timestamps = _sort_frames_by_timestamps_then_numbers (
115
- seq_frame_indices , self .dataset
112
+ segments = self . _split_to_segments (
113
+ self .dataset . sequence_frames_in_order ( seq )
116
114
)
117
- # TODO: use new API to access frame numbers / timestamps
118
- segments = self ._split_to_segments (sequence_timestamps )
119
115
segments = _cull_short_segments (segments , size )
120
116
if not segments :
121
117
raise AssertionError ("Empty segments after culling" )
122
118
else :
123
- segments = [seq_frame_indices ]
119
+ segments = [list ( self . dataset . sequence_indices_in_order ( seq )) ]
124
120
125
121
# build an index of segment for random selection of a pivot frame
126
122
segment_index = [
@@ -130,7 +126,7 @@ def _build_segment_index(
130
126
return segment_index
131
127
132
128
def _split_to_segments (
133
- self , sequence_timestamps : List [Tuple [float , int , int ]]
129
+ self , sequence_timestamps : Iterable [Tuple [float , int , int ]]
134
130
) -> List [List [int ]]:
135
131
if (
136
132
self .consecutive_frames_max_gap <= 0
@@ -144,7 +140,7 @@ def _split_to_segments(
144
140
for ts , no , idx in sequence_timestamps :
145
141
if ts <= 0.0 and no <= last_no :
146
142
raise AssertionError (
147
- "Frames are not ordered in seq_to_idx while timestamps are not given"
143
+ "Sequence frames are not ordered while timestamps are not given"
148
144
)
149
145
150
146
if (
@@ -161,23 +157,6 @@ def _split_to_segments(
161
157
return segments
162
158
163
159
164
- def _sort_frames_by_timestamps_then_numbers (
165
- seq_frame_indices : List [int ], dataset : ImplicitronDatasetBase
166
- ) -> List [Tuple [float , int , int ]]:
167
- """Build the list of triplets (timestamp, frame_no, dataset_idx).
168
- We attempt to first sort by timestamp, then by frame number.
169
- Timestamps are coalesced with 0s.
170
- """
171
- nos_timestamps = dataset .get_frame_numbers_and_timestamps (seq_frame_indices )
172
-
173
- return sorted (
174
- [
175
- (timestamp , frame_no , idx )
176
- for idx , (frame_no , timestamp ) in zip (seq_frame_indices , nos_timestamps )
177
- ]
178
- )
179
-
180
-
181
160
def _cull_short_segments (segments : List [List [int ]], min_size : int ) -> List [List [int ]]:
182
161
lengths = [(len (segment ), segment ) for segment in segments ]
183
162
max_len , longest_segment = max (lengths )
0 commit comments