From e5ff90e008d91605afc724d1f76500b3dc885434 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 7 May 2024 16:14:44 +0200
Subject: [PATCH] Fix scan_checkpoints with padded sequences

---
 pytensor/scan/checkpoints.py   | 14 ++++++++------
 tests/scan/test_checkpoints.py |  9 ++++++---
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/pytensor/scan/checkpoints.py b/pytensor/scan/checkpoints.py
index 36dc2af1fe..8c237267d5 100644
--- a/pytensor/scan/checkpoints.py
+++ b/pytensor/scan/checkpoints.py
@@ -1,7 +1,7 @@
 import pytensor.tensor.basic as ptb
 from pytensor.scan.basic import scan
 from pytensor.tensor.basic import Join
-from pytensor.tensor.math import ceil, eq
+from pytensor.tensor.math import ceil, eq, neq
 from pytensor.tensor.subtensor import set_subtensor
 
 
@@ -130,16 +130,18 @@ def scan_checkpoints(
     # Since padding could be an empty tensor, Join returns a view of s.
     join = Join(view=0)
     for i, s in enumerate(sequences):
-        n = s.shape[0] % save_every_N
-        z = ptb.zeros((n, s.shape[1:]), dtype=s.dtype)
-        sequences[i] = join(0, [s, z])
+        overshoots_by = s.shape[0] % save_every_N
+        overshoots = neq(overshoots_by, 0)
+        n = (save_every_N - overshoots_by) * overshoots
+        z = ptb.zeros((n, *s.shape[1:]), dtype=s.dtype)
+        sequences[i] = join(0, s, z)
 
     # Establish the input variables of the outer scan
     o_sequences = [
         s.reshape(
-            [s.shape[0] / save_every_N, save_every_N]
+            [s.shape[0] // save_every_N, save_every_N]
             + [s.shape[i] for i in range(1, s.ndim)],
-            s.ndim + 1,
+            ndim=s.ndim + 1,
         )
         for s in sequences
     ]
diff --git a/tests/scan/test_checkpoints.py b/tests/scan/test_checkpoints.py
index 0e8796ae0e..b30c1582fe 100644
--- a/tests/scan/test_checkpoints.py
+++ b/tests/scan/test_checkpoints.py
@@ -5,7 +5,7 @@
 from pytensor.gradient import grad
 from pytensor.scan.basic import scan
 from pytensor.scan.checkpoints import scan_checkpoints
-from pytensor.tensor.basic import ones_like
+from pytensor.tensor.basic import arange, ones_like
 from pytensor.tensor.type import iscalar, vector
 
 
@@ -13,15 +13,18 @@ class TestScanCheckpoint:
     def setup_method(self):
         self.k = iscalar("k")
         self.A = vector("A")
+        seq = arange(self.k, dtype="float32") + 1
         result, _ = scan(
-            fn=lambda prior_result, A: prior_result * A,
+            fn=lambda s, prior_result, A: prior_result * A / s,
             outputs_info=ones_like(self.A),
+            sequences=[seq],
             non_sequences=self.A,
             n_steps=self.k,
         )
         result_check, _ = scan_checkpoints(
-            fn=lambda prior_result, A: prior_result * A,
+            fn=lambda s, prior_result, A: prior_result * A / s,
             outputs_info=ones_like(self.A),
+            sequences=[seq],
             non_sequences=self.A,
             n_steps=self.k,
             save_every_N=100,