diff --git a/pyproject.toml b/pyproject.toml index 4608eecc..e2fb0117 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dev = [ "numpy", "pytest", "pillow", + "torcheval", ] [tool.usort] diff --git a/test/test_ops.py b/test/test_ops.py index 158e3d08..6ee4b6e8 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -105,15 +105,13 @@ def test_get_frame_at_pts(self, device): frame6, _, _ = get_frame_at_pts(decoder, 6.02) assert_frames_equal(frame6, reference_frame6.to(device)) frame6, _, _ = get_frame_at_pts(decoder, 6.039366) - assert_frames_equal(frame6, reference_frame6.to(device)) + prev_frame_psnr = assert_frames_equal(frame6, reference_frame6.to(device)) # Note that this timestamp is exactly on a frame boundary, so it should # return the next frame since the right boundary of the interval is # open. next_frame, _, _ = get_frame_at_pts(decoder, 6.039367) - if device == "cpu": - # We can only compare exact equality on CPU. - with pytest.raises(AssertionError): - assert_frames_equal(next_frame, reference_frame6.to(device)) + with pytest.raises(AssertionError): + assert_frames_equal(next_frame, reference_frame6.to(device), psnr=prev_frame_psnr) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_at_index(self, device): diff --git a/test/test_samplers.py b/test/test_samplers.py index 72ee108e..164bfc6c 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -250,11 +250,11 @@ def test_sampling_range( cm = ( contextlib.nullcontext() if assert_all_equal - else pytest.raises(AssertionError, match="Tensor-likes are not") + else pytest.raises(AssertionError, match="low psnr") ) with cm: for clip in clips: - assert_frames_equal(clip.data, clips[0].data) + assert_frames_equal(clip.data, clips[0].data, psnr=float("inf")) @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) @@ -447,7 +447,7 @@ def test_random_sampler_randomness(sampler): # Call with a different seed, expect different results torch.manual_seed(1) clips_3 = sampler(decoder, num_clips=num_clips) - with pytest.raises(AssertionError, match="Tensor-likes are not"): + with pytest.raises(AssertionError, match="low psnr"): assert_frames_equal(clips_1[0].data, clips_3[0].data) # Make sure we didn't alter the builtin Python RNG diff --git a/test/utils.py b/test/utils.py index e7ce12e5..6c3abb5c 100644 --- a/test/utils.py +++ b/test/utils.py @@ -30,25 +30,23 @@ def get_ffmpeg_major_version(): return int(get_ffmpeg_library_versions()["ffmpeg_version"].split(".")[0]) -# For use with decoded data frames. On CPU Linux, we expect exact, bit-for-bit -# equality. On CUDA Linux, we expect a small tolerance. -# On other platforms (e.g. MacOS), we also allow a small tolerance. FFmpeg does -# not guarantee bit-for-bit equality across systems and architectures, so we -# also cannot. We currently use Linux on x86_64 as our reference system. -def assert_frames_equal(*args, **kwargs): - if sys.platform == "linux": - if args[0].device.type == "cuda": - atol = 2 - if get_ffmpeg_major_version() == 4: - assert_tensor_close_on_at_least( - args[0], args[1], percentage=95, atol=atol - ) - else: - torch.testing.assert_close(*args, **kwargs, atol=atol, rtol=0) - else: - torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0) +# For use with decoded data frames. `psnr` sets the PSNR threshold when +# frames are considered equal. `float("inf")` correspond to bit-to-bit +# identical frames. Function returns calculated psnr value. +def assert_frames_equal(input, other, psnr=40, msg=None): + if torch.allclose(input, other, atol=0, rtol=0): + return float("inf") else: - torch.testing.assert_close(*args, **kwargs, atol=3, rtol=0) + from torcheval.metrics import PeakSignalNoiseRatio + + metric = PeakSignalNoiseRatio() + metric.update(input, other) + m = metric.compute() + message = f"low psnr: {m} < {psnr}" + if (msg): + message += f" ({msg})" + assert m >= psnr, message + return m # Asserts that at least `percentage`% of the values are within the absolute tolerance.