We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
inplace
set_subtensor
1 parent 3e65827 commit c2b8465Copy full SHA for c2b8465
pytensor/link/jax/dispatch/pad.py
@@ -0,0 +1,16 @@
1
+import jax.numpy as jnp
2
+
3
+from pytensor.link.jax.dispatch import jax_funcify
4
+from pytensor.tensor.pad import Pad, allowed_kwargs
5
6
7
+@jax_funcify([Pad])
8
+def jax_funcify_pad(op, **kwargs):
9
+ pad_mode = op.pad_mode
10
+ expected_kwargs = allowed_kwargs[pad_mode]
11
+ mode_kwargs = {kwarg: getattr(op, kwarg) for kwarg in expected_kwargs}
12
13
+ def pad(x, pad_width, pad_mode=pad_mode):
14
+ return jnp.pad(x, pad_width=pad_width, pad_mode=pad_mode, **mode_kwargs)
15
16
+ return pad
0 commit comments