-
Notifications
You must be signed in to change notification settings - Fork 39
Add `AudioDecoder.get_samples_played_in_range()` public method
#555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4495150
277fac2
0f9e14d
00bb28d
9a00c91
a7b67d5
02067cf
e5c9831
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -12,7 +12,7 @@ | |||||
|
||||||
|
||||||
def _frame_repr(self): | ||||||
# Utility to replace Frame and FrameBatch __repr__ method. This prints the | ||||||
# Utility to replace __repr__ method of dataclasses below. This prints the | ||||||
# shape of the .data tensor rather than printing the (potentially very long) | ||||||
# data tensor itself. | ||||||
s = self.__class__.__name__ + ":\n" | ||||||
|
@@ -114,3 +114,28 @@ def __len__(self): | |||||
|
||||||
def __repr__(self): | ||||||
return _frame_repr(self) | ||||||
|
||||||
|
||||||
@dataclass | ||||||
class AudioSamples(Iterable): | ||||||
"""Audio samples with associated metadata.""" | ||||||
|
||||||
# TODO-AUDIO: docs | ||||||
data: Tensor | ||||||
pts_seconds: float | ||||||
sample_rate: int | ||||||
|
||||||
def __post_init__(self): | ||||||
# This is called after __init__() when a Frame is created. We can run | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
# input validation checks here. | ||||||
if not self.data.ndim == 2: | ||||||
raise ValueError(f"data must be 2-dimensional, got {self.data.shape = }") | ||||||
self.pts_seconds = float(self.pts_seconds) | ||||||
self.sample_rate = int(self.sample_rate) | ||||||
|
||||||
def __iter__(self) -> Iterator[Union[Tensor, float]]: | ||||||
for field in dataclasses.fields(self): | ||||||
yield getattr(self, field.name) | ||||||
|
||||||
def __repr__(self): | ||||||
return _frame_repr(self) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
|
||
from torch import Tensor | ||
|
||
from torchcodec import AudioSamples | ||
from torchcodec.decoders import _core as core | ||
from torchcodec.decoders._decoder_utils import ( | ||
create_decoder, | ||
|
@@ -37,3 +38,70 @@ def __init__( | |
) = get_and_validate_stream_metadata( | ||
decoder=self._decoder, stream_index=stream_index, media_type="audio" | ||
) | ||
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy | ||
|
||
def get_samples_played_in_range( | ||
self, start_seconds: float, stop_seconds: Optional[float] = None | ||
) -> AudioSamples: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I feel like There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we do that, we should also mirror that in the video API. We could try to play the same trick that There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Agreed, I'll update #150 eventually with more of these desirable default behaviors. That would be a good onboarding PR. |
||
"""TODO-AUDIO docs""" | ||
if stop_seconds is not None and not start_seconds <= stop_seconds: | ||
raise ValueError( | ||
f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})." | ||
) | ||
if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds: | ||
raise ValueError( | ||
f"Invalid start seconds: {start_seconds}. " | ||
f"It must be greater than or equal to {self._begin_stream_seconds} " | ||
f"and less than or equal to {self._end_stream_seconds}." | ||
) | ||
frames, first_pts = core.get_frames_by_pts_in_range_audio( | ||
self._decoder, | ||
start_seconds=start_seconds, | ||
stop_seconds=stop_seconds, | ||
) | ||
first_pts = first_pts.item() | ||
|
||
# x = frame boundaries | ||
# | ||
# first_pts last_pts | ||
# v v | ||
# ....x..........x..........x...........x..........x..........x..... | ||
# ^ ^ | ||
# start_seconds stop_seconds | ||
# | ||
# We want to return the samples in [start_seconds, stop_seconds). But | ||
# because the core API is based on frames, the `frames` tensor contains | ||
# the samples in [first_pts, last_pts) | ||
# So we do some basic math to figure out the position of the view that | ||
# we'll return. | ||
|
||
# TODO: sample_rate is either the original one from metadata, or the | ||
# user-specified one (NIY) | ||
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy | ||
sample_rate = self.metadata.sample_rate | ||
|
||
# TODO: metadata's sample_rate should probably not be Optional | ||
assert sample_rate is not None # mypy. | ||
|
||
if first_pts < start_seconds: | ||
offset_beginning = round((start_seconds - first_pts) * sample_rate) | ||
output_pts_seconds = start_seconds | ||
else: | ||
# In normal cases we'll have first_pts <= start_pts, but in some | ||
# edge cases it's possible to have first_pts > start_seconds, | ||
# typically if the stream's first frame's pts isn't exactly 0. | ||
offset_beginning = 0 | ||
output_pts_seconds = first_pts | ||
|
||
num_samples = frames.shape[1] | ||
last_pts = first_pts + num_samples / self.metadata.sample_rate | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Won't There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes it's a float. I didn't add the We do call |
||
if stop_seconds is not None and stop_seconds < last_pts: | ||
offset_end = num_samples - round((last_pts - stop_seconds) * sample_rate) | ||
else: | ||
offset_end = num_samples | ||
|
||
return AudioSamples( | ||
data=frames[:, offset_beginning:offset_end], | ||
pts_seconds=output_pts_seconds, | ||
sample_rate=sample_rate, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -854,7 +854,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio( | |
|
||
if (startSeconds == stopSeconds) { | ||
// For consistency with video | ||
return AudioFramesOutput{torch::empty({0}), 0.0}; | ||
return AudioFramesOutput{torch::empty({0, 0}), 0.0}; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Drive-by, the video APIs return something of shape |
||
} | ||
|
||
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Below: the audio decoding API returns the new `AudioSamples` class rather than a pure Tensor. I think it has the following benefits: `pts_seconds`, which may not always be equal to `start_seconds`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this makes sense, and it mirrors the video API. In the video API, `__getitem__()` returns just the tensor, but the named methods return `Frame` or `FrameBatch`. I think we should probably do that for audio as well.