From 49e4187811e4ad84dc46c4b3fd647995c4c7fa38 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 29 Apr 2025 17:13:37 +0100 Subject: [PATCH 1/2] Ensure correct encoding for non-contiguous WF --- src/torchcodec/_core/Encoder.cpp | 4 +--- test/test_ops.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 114e8600..1e0e75c3 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -13,10 +13,8 @@ torch::Tensor validateWf(torch::Tensor wf) { wf.dtype() == torch::kFloat32, "waveform must have float32 dtype, got ", wf.dtype()); - // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed - // planar (fltp). TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim()); - return wf; + return wf.contiguous(); } void validateSampleRate(const AVCodec& avCodec, int sampleRate) { diff --git a/test/test_ops.py b/test/test_ops.py index 0b304bad..0ae6bbc6 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1284,6 +1284,35 @@ def test_encode_to_tensor_long_output(self): torch.testing.assert_close(self.decode(encoded_tensor), samples) + def test_contiguity(self): + num_samples = 10_000 # per channel + contiguous_samples = torch.rand(2, num_samples).contiguous() + assert contiguous_samples.stride() == (num_samples, 1) + + encoded_from_contiguous = encode_audio_to_tensor( + wf=contiguous_samples, + sample_rate=16_000, + format="flac", + bit_rate=44_000, + ) + non_contiguous_samples = contiguous_samples.T.contiguous().T + assert non_contiguous_samples.stride() == (1, 2) + + torch.testing.assert_close( + contiguous_samples, non_contiguous_samples, rtol=0, atol=0 + ) + + encoded_from_non_contiguous = encode_audio_to_tensor( + wf=non_contiguous_samples, + sample_rate=16_000, + format="flac", + bit_rate=44_000, + ) + + torch.testing.assert_close( + encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0 + ) + if __name__ == "__main__": pytest.main() From 88d51d21f0f46c560c8f4bbddb8d85c1c5a7068e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 30 Apr 2025 10:34:39 +0100 Subject: [PATCH 2/2] Add comment --- test/test_ops.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index d28510a9..ddca330a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1268,6 +1268,10 @@ def test_encode_to_tensor_long_output(self): torch.testing.assert_close(self.decode(encoded_tensor), samples) def test_contiguity(self): + # Ensure that 2 waveforms with the same values are encoded in the same + # way, regardless of their memory layout. Here we encode 2 equal + # waveforms, one is row-aligned while the other is column-aligned. + num_samples = 10_000 # per channel contiguous_samples = torch.rand(2, num_samples).contiguous() assert contiguous_samples.stride() == (num_samples, 1)