Need a little guidance on using pytensor.scan() #518
Unanswered
mike-lawrence asked this question in Q&A
Replies: 1 comment
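Here is one option:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.ifelse import ifelse

# scan passes the current element of each sequence first, then the previous
# value of each recurrent output listed in outputs_info.
def step(a, last_a_copied, copy_count):
    # Once the previous value has been copied as many times as its own value,
    # take the current element and restart the counter; otherwise keep copying.
    last_a_copied, copy_count = ifelse(
        copy_count >= last_a_copied,
        (a, pt.ones_like(copy_count)),
        (last_a_copied, copy_count + 1),
    )
    return last_a_copied, copy_count

a = pt.vector("a", dtype="int64")
# Or any value that's larger than any in a; it only has to be >= a[0]
# so that the first step picks up a[0].
copy_count = a.max() + 1
last_a_copied = a[0]

[b, _], _ = pytensor.scan(
    step,
    sequences=[a],
    outputs_info=[last_a_copied, copy_count],
)

b.eval({a: [2, 1, 3, 2, 4, 4, 3, 3, 5, 5, 2]})  # array([2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5])
```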
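To check what the scan above computes without building a graph, here is a plain-Python sketch of the same recurrence. `toy_scan` and `copy_values` are hypothetical helpers, not part of pytensor; `toy_scan` only mimics scan's argument order (current sequence element first, then the previous value of each entry in `outputs_info`) and collects the per-step outputs. The function the question refers to is not shown in the thread, so this reproduces the answer's behavior, not necessarily the asker's original: each value that is picked up is repeated as many times as its own value, and then the element at the current position is picked up next.

```python
from typing import Callable, List, Sequence, Tuple


def toy_scan(step: Callable, sequence: Sequence[int],
             outputs_info: Tuple[int, ...]) -> List[Tuple[int, ...]]:
    """Hypothetical stand-in for scan's calling convention (NOT the pytensor API).

    step gets the current sequence element, then the previous value of every
    recurrent output, and returns the new values of those outputs.
    """
    state = tuple(outputs_info)
    history = []
    for x in sequence:
        state = tuple(step(x, *state))
        history.append(state)  # scan stacks every step's outputs like this
    return history


def step(a: int, last_a_copied: int, copy_count: int) -> Tuple[int, int]:
    # Same branch logic as the ifelse in the answer, with a plain Python if.
    if copy_count >= last_a_copied:
        return a, 1  # pick up the current element and restart the counter
    return last_a_copied, copy_count + 1  # keep copying the previous value


def copy_values(a: Sequence[int]) -> List[int]:
    # Initial state mirrors the answer: last_a_copied = a[0], copy_count = max(a) + 1.
    history = toy_scan(step, a, (a[0], max(a) + 1))
    return [last for last, _ in history]  # first output only, like b


print(copy_values([2, 1, 3, 2, 4, 4, 3, 3, 5, 5, 2]))
# [2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5]
```

Running the plain-Python version on other inputs is a cheap way to confirm the scan does what you intend before compiling it; the symbolic version should agree step for step because the branch condition and state updates are identical.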
How would the following function be implemented using scan?