Skip to content

Commit c2b8465

Browse files
Remove inplace argument to set_subtensor
1 parent 3e65827 commit c2b8465

File tree

1 file changed

+16
-0
lines changed
  • pytensor/link/jax/dispatch

1 file changed

+16
-0
lines changed

pytensor/link/jax/dispatch/pad.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import jax.numpy as jnp
2+
3+
from pytensor.link.jax.dispatch import jax_funcify
4+
from pytensor.tensor.pad import Pad, allowed_kwargs
5+
6+
7+
@jax_funcify([Pad])
8+
def jax_funcify_pad(op, **kwargs):
9+
pad_mode = op.pad_mode
10+
expected_kwargs = allowed_kwargs[pad_mode]
11+
mode_kwargs = {kwarg: getattr(op, kwarg) for kwarg in expected_kwargs}
12+
13+
def pad(x, pad_width, pad_mode=pad_mode):
14+
return jnp.pad(x, pad_width=pad_width, pad_mode=pad_mode, **mode_kwargs)
15+
16+
return pad

0 commit comments

Comments
 (0)